diff options
30 files changed, 269 insertions, 276 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index c975ad9cf..12b061def 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -61,8 +61,8 @@ func TestTimeouts(t *testing.T) { func newLoopbackStack() (*stack.Stack, *tcpip.Error) { // Create the stack and add a NIC. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, }) if err := s.CreateNIC(NICID, loopback.New()); err != nil { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index cb9225bd7..b025bb087 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -15,14 +15,6 @@ // Package arp implements the ARP network protocol. It is used to resolve // IPv4 addresses into link-local MAC addresses, and advertises IPv4 // addresses of its stack with the local network. -// -// To use it in the networking stack, pass arp.NewProtocol() as one of the -// network protocols when calling stack.New. Then add an "arp" address to every -// NIC on the stack that should respond to ARP requests. That is: -// -// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { -// // handle err -// } package arp import ( @@ -239,6 +231,10 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } // NewProtocol returns an ARP network protocol. -func NewProtocol() stack.NetworkProtocol { +// +// Note, to make sure that the ARP endpoint receives ARP packets, the "arp" +// address must be added to every NIC that should respond to ARP requests. See +// ProtocolAddress for more details. +func NewProtocol(*stack.Stack) stack.NetworkProtocol { return &protocol{} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 9c9a859e3..626af975a 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -176,8 +176,8 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, NUDConfigs: c, NUDDisp: &d, UseNeighborCache: useNeighborCache, @@ -442,7 +442,7 @@ func TestLinkAddressRequest(t *testing.T) { } for _, test := range tests { - p := arp.NewProtocol() + p := arp.NewProtocol(nil) linkRes, ok := p.(stack.LinkAddressResolver) if !ok { t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 477635f97..4640ca95c 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -195,8 +195,8 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) s.AddAddress(nicID, ipv4.ProtocolNumber, local) @@ -211,8 +211,8 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) s.AddAddress(nicID, ipv6.ProtocolNumber, local) @@ -229,8 +229,8 @@ func buildDummyStack(t *testing.T) *stack.Stack { t.Helper() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) e := channel.New(0, 1280, "") if err := s.CreateNIC(nicID, e); err != nil { @@ -250,8 +250,9 @@ func buildDummyStack(t *testing.T) *stack.Stack { func TestIPv4Send(t *testing.T) { o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, s) defer ep.Close() // Allocate and initialize the payload view. @@ -287,8 +288,9 @@ func TestIPv4Send(t *testing.T) { func TestIPv4Receive(t *testing.T) { o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) defer ep.Close() totalLen := header.IPv4MinimumSize + 30 @@ -357,8 +359,9 @@ func TestIPv4ReceiveControl(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} - proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) defer ep.Close() const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize @@ -418,8 +421,9 @@ func TestIPv4ReceiveControl(t *testing.T) { func TestIPv4FragmentationReceive(t *testing.T) { o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) defer ep.Close() totalLen := header.IPv4MinimumSize + 24 @@ -495,8 +499,9 @@ func TestIPv4FragmentationReceive(t *testing.T) { func TestIPv6Send(t *testing.T) { o := testObject{t: t} - proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), s) defer ep.Close() // Allocate and initialize the payload view. @@ -532,8 +537,9 @@ func TestIPv6Send(t *testing.T) { func TestIPv6Receive(t *testing.T) { o := testObject{t: t} - proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) defer ep.Close() totalLen := header.IPv6MinimumSize + 30 @@ -611,8 +617,9 @@ func TestIPv6ReceiveControl(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} - proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) defer ep.Close() dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e589d923d..254d66147 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ipv4 contains the implementation of the ipv4 network protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv4.NewProtocol() as one of the network -// protocols when calling stack.New(). Then endpoints can be created by passing -// ipv4.ProtocolNumber as the network protocol number when calling -// Stack.NewEndpoint(). +// Package ipv4 contains the implementation of the ipv4 network protocol. package ipv4 import ( @@ -584,7 +579,7 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV ui } // NewProtocol returns an IPv4 network protocol. -func NewProtocol() stack.NetworkProtocol { +func NewProtocol(*stack.Stack) stack.NetworkProtocol { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 86187aba8..0b3ed9483 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -36,8 +36,8 @@ import ( func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) const defaultMTU = 65536 @@ -517,8 +517,8 @@ func TestInvalidFragments(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, }, }) e := channel.New(0, 1500, linkAddr) @@ -929,8 +929,8 @@ func TestReceiveFragments(t *testing.T) { t.Run(test.name, func(t *testing.T) { // Setup a stack and endpoint. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) if err := s.CreateNIC(nicID, e); err != nil { @@ -1140,7 +1140,7 @@ func TestWriteStats(t *testing.T) { func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatalf("CreateNIC(1, _) failed: %s", err) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index f82e34209..dd58022d6 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -121,8 +121,8 @@ func TestICMPCounts(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: test.useNeighborCache, }) { @@ -259,8 +259,8 @@ func TestICMPCounts(t *testing.T) { func TestICMPCountsWithNeighborCache(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: true, }) { @@ -424,12 +424,12 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab func newTestContext(t *testing.T) *testContext { c := &testContext{ s0: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }), s1: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }), } @@ -724,7 +724,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: test.useNeighborCache, }) if isRouter { @@ -920,7 +920,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { t.Run(typ.name, func(t *testing.T) { e := channel.New(10, 1280, linkAddr0) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) @@ -1098,7 +1098,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { t.Run(typ.name, func(t *testing.T) { e := channel.New(10, 1280, linkAddr0) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -1204,7 +1204,7 @@ func TestLinkAddressRequest(t *testing.T) { } for _, test := range tests { - p := NewProtocol() + p := NewProtocol(nil) linkRes, ok := p.(stack.LinkAddressResolver) if !ok { t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 3d070ddea..e436c6a9e 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ipv6 contains the implementation of the ipv6 network protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv6.NewProtocol() as one of the network -// protocols when calling stack.New(). Then endpoints can be created by passing -// ipv6.ProtocolNumber as the network protocol number when calling -// Stack.NewEndpoint(). +// Package ipv6 contains the implementation of the ipv6 network protocol. package ipv6 import ( @@ -617,7 +612,7 @@ func calculateMTU(mtu uint32) uint32 { } // NewProtocol returns an IPv6 network protocol. -func NewProtocol() stack.NetworkProtocol { +func NewProtocol(*stack.Stack) stack.NetworkProtocol { return &protocol{ defaultTTL: DefaultTTL, fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 7d138dadb..8ae146c5e 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -141,18 +141,18 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { tests := []struct { name string - protocolFactory stack.TransportProtocol + protocolFactory stack.TransportProtocolFactory rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, + {"ICMP", icmp.NewProtocol6, testReceiveICMP}, + {"UDP", udp.NewProtocol, testReceiveUDP}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { @@ -174,11 +174,11 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { tests := []struct { name string - protocolFactory stack.TransportProtocol + protocolFactory stack.TransportProtocolFactory rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, + {"ICMP", icmp.NewProtocol6, testReceiveICMP}, + {"UDP", udp.NewProtocol, testReceiveUDP}, } snmc := header.SolicitedNodeAddr(addr2) @@ -186,8 +186,8 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) e := channel.New(1, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { @@ -273,7 +273,7 @@ func TestAddIpv6Address(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -579,8 +579,8 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { @@ -1549,8 +1549,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { @@ -1668,8 +1668,8 @@ func TestInvalidIPv6Fragments(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + NewProtocol, }, }) e := channel.New(0, 1500, linkAddr1) @@ -1847,7 +1847,7 @@ func TestWriteStats(t *testing.T) { func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatalf("CreateNIC(1, _) failed: %s", err) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 7434df4a1..c93d1194f 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -35,8 +35,8 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig t.Helper() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: useNeighborCache, }) @@ -98,7 +98,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) if err := s.CreateNIC(nicID, e); err != nil { @@ -202,7 +202,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: true, }) e := channel.New(0, 1280, linkAddr0) @@ -475,7 +475,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: stackTyp.useNeighborCache, }) e := channel.New(1, 1280, nicLinkAddr) @@ -596,7 +596,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) if err := s.CreateNIC(nicID, e); err != nil { @@ -707,7 +707,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: true, }) e := channel.New(0, 1280, linkAddr0) @@ -1172,7 +1172,7 @@ func TestRouterAdvertValidation(t *testing.T) { e := channel.New(10, 1280, linkAddr1) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: stackTyp.useNeighborCache, }) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 91fc26722..51d428049 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -127,8 +127,8 @@ func main() { // Create the stack with ipv4 and tcp protocols, then add a tun-based // NIC and ipv4 address. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) mtu, err := rawfile.GetMTU(tunName) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 3f58a15ea..8e0ee1cd7 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -112,8 +112,8 @@ func main() { // Create the stack with ip and tcp protocols, then add a tun-based // NIC and address. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) mtu, err := rawfile.GetMTU(tunName) diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 38c5bac71..8d18f3c8c 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -307,7 +307,7 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ - NetworkProtocols: []NetworkProtocol{proto}, + NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, UseNeighborCache: useNeighborCache, }) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 5e43a9b0b..8416dbcdb 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -320,7 +320,7 @@ func TestDADDisabled(t *testing.T) { dadC: make(chan ndpDADEvent, 1), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPDisp: &ndpDisp, } @@ -414,7 +414,7 @@ func TestDADResolve(t *testing.T) { dadC: make(chan ndpDADEvent), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPDisp: &ndpDisp, } opts.NDPConfigs.RetransmitTimer = test.retransTimer @@ -638,7 +638,7 @@ func TestDADFail(t *testing.T) { } ndpConfigs := stack.DefaultNDPConfigurations() opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, } @@ -759,7 +759,7 @@ func TestDADStop(t *testing.T) { DupAddrDetectTransmits: 2, } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPDisp: &ndpDisp, NDPConfigs: ndpConfigs, } @@ -819,7 +819,7 @@ func TestDADStop(t *testing.T) { // we attempt to update NDP configurations using an invalid NICID. func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, }) // No NIC with ID 1 yet. @@ -863,7 +863,7 @@ func TestSetNDPConfigurations(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPDisp: &ndpDisp, }) @@ -1113,7 +1113,7 @@ func TestNoRouterDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: handle, DiscoverDefaultRouters: discover, @@ -1151,7 +1151,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: true, @@ -1192,7 +1192,7 @@ func TestRouterDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: true, @@ -1293,7 +1293,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: true, @@ -1358,7 +1358,7 @@ func TestNoPrefixDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: handle, DiscoverOnLinkPrefixes: discover, @@ -1399,7 +1399,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: false, @@ -1445,7 +1445,7 @@ func TestPrefixDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverOnLinkPrefixes: true, @@ -1545,7 +1545,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverOnLinkPrefixes: true, @@ -1629,7 +1629,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: false, @@ -1716,7 +1716,7 @@ func TestNoAutoGenAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: handle, AutoGenGlobalAddresses: autogen, @@ -1766,7 +1766,7 @@ func TestAutoGenAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -1931,7 +1931,7 @@ func TestAutoGenTempAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ DupAddrDetectTransmits: test.dupAddrTransmits, RetransmitTimer: test.retransmitTimer, @@ -2160,7 +2160,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ AutoGenTempGlobalAddresses: true, }, @@ -2228,7 +2228,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, @@ -2324,7 +2324,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { RegenAdvanceDuration: newMinVLDuration - regenAfter, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, }) @@ -2469,7 +2469,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { RegenAdvanceDuration: newMinVLDuration - regenAfter, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, }) @@ -2662,8 +2662,8 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { AutoGenAddressConflictRetries: 1, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ @@ -2794,8 +2794,8 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -3307,7 +3307,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -3449,7 +3449,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { } e := channel.New(10, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -3515,7 +3515,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -3700,7 +3700,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -3781,7 +3781,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -4029,7 +4029,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { ndpConfigs := addrType.ndpConfigs ndpConfigs.AutoGenAddressConflictRetries = maxRetries s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, @@ -4165,7 +4165,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, NDPConfigs: addrType.ndpConfigs, NDPDisp: &ndpDisp, @@ -4250,7 +4250,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, @@ -4459,7 +4459,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, }, @@ -4509,7 +4509,7 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, }, @@ -4694,7 +4694,7 @@ func TestCleanupNDPState(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, AutoGenIPv6LinkLocal: true, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, @@ -4967,7 +4967,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, }, @@ -5217,7 +5217,7 @@ func TestRouterSolicitation(t *testing.T) { } } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ MaxRtrSolicitations: test.maxRtrSolicit, RtrSolicitationInterval: test.rtrSolicitInt, @@ -5357,7 +5357,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { checker.NDPRS()) } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ MaxRtrSolicitations: maxRtrSolicitations, RtrSolicitationInterval: interval, diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index dd6474297..bc9c9881a 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -241,6 +241,10 @@ func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAdd return "", false } +func newTestIPv6Protocol(*Stack) NetworkProtocol { + return &testIPv6Protocol{} +} + // Test the race condition where a NIC is removed and an RS timer fires at the // same time. func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { @@ -252,7 +256,7 @@ func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { e := testLinkEndpoint{} s := New(Options{ - NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}}, + NetworkProtocols: []NetworkProtocolFactory{newTestIPv6Protocol}, NDPConfigs: NDPConfigurations{ MaxRtrSolicitations: maxRtrSolicitations, RtrSolicitationInterval: minimumRtrSolicitationInterval, diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 2b97e5972..8cffb9fc6 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -60,7 +60,7 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The networking // stack will only allocate neighbor caches if a protocol providing link // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, UseNeighborCache: true, }) @@ -137,7 +137,7 @@ func TestDefaultNUDConfigurations(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The networking // stack will only allocate neighbor caches if a protocol providing link // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: stack.DefaultNUDConfigurations(), UseNeighborCache: true, }) @@ -192,7 +192,7 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -249,7 +249,7 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -329,7 +329,7 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -391,7 +391,7 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -443,7 +443,7 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -495,7 +495,7 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -547,7 +547,7 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -599,7 +599,7 @@ func TestNUDConfigurationsUnreachableTime(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e7b7e95d4..c22633f6b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -517,13 +517,25 @@ type UniqueID interface { UniqueID() uint64 } +// NetworkProtocolFactory instantiates a network protocol. +// +// NetworkProtocolFactory must not attempt to modify the stack, it may only +// query the stack. +type NetworkProtocolFactory func(*Stack) NetworkProtocol + +// TransportProtocolFactory instantiates a transport protocol. +// +// TransportProtocolFactory must not attempt to modify the stack, it may only +// query the stack. +type TransportProtocolFactory func(*Stack) TransportProtocol + // Options contains optional Stack configuration. type Options struct { // NetworkProtocols lists the network protocols to enable. - NetworkProtocols []NetworkProtocol + NetworkProtocols []NetworkProtocolFactory // TransportProtocols lists the transport protocols to enable. - TransportProtocols []TransportProtocol + TransportProtocols []TransportProtocolFactory // Clock is an optional clock source used for timestampping packets. // @@ -755,7 +767,8 @@ func New(opts Options) *Stack { s.forwarding.protocols = make(map[tcpip.NetworkProtocolNumber]bool) // Add specified network protocols. - for _, netProto := range opts.NetworkProtocols { + for _, netProtoFactory := range opts.NetworkProtocols { + netProto := netProtoFactory(s) s.networkProtocols[netProto.Number()] = netProto if r, ok := netProto.(LinkAddressResolver); ok { s.linkAddrResolvers[r.LinkAddressProtocol()] = r @@ -763,7 +776,8 @@ func New(opts Options) *Stack { } // Add specified transport protocols. - for _, transProto := range opts.TransportProtocols { + for _, transProtoFactory := range opts.TransportProtocols { + transProto := transProtoFactory(s) s.transportProtocols[transProto.Number()] = &transportProtocolState{ proto: transProto, } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index f92850053..c205650fe 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -231,7 +231,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } -func fakeNetFactory() stack.NetworkProtocol { +func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { return &fakeNetworkProtocol{} } @@ -268,7 +268,7 @@ func TestNetworkReceive(t *testing.T) { // addresses attached to it: 1 & 2. ep := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -428,7 +428,7 @@ func TestNetworkSend(t *testing.T) { // existing nic. ep := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) @@ -455,7 +455,7 @@ func TestNetworkSendMultiRoute(t *testing.T) { // addresses per nic, the first nic has odd address, the second one has // even addresses. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") @@ -555,7 +555,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := linkEPWithMockedAttach{ @@ -574,7 +574,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { func TestDisableUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { @@ -586,7 +586,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := loopback.New() @@ -633,7 +633,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { func TestRemoveUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { @@ -645,7 +645,7 @@ func TestRemoveNIC(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := linkEPWithMockedAttach{ @@ -706,7 +706,7 @@ func TestRouteWithDownNIC(t *testing.T) { setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(1, defaultMTU, "") @@ -872,7 +872,7 @@ func TestRoutes(t *testing.T) { // addresses per nic, the first nic has odd address, the second one has // even addresses. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") @@ -952,7 +952,7 @@ func TestAddressRemoval(t *testing.T) { remoteAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -999,7 +999,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { remoteAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1090,7 +1090,7 @@ func TestEndpointExpiration(t *testing.T) { for _, spoofing := range []bool{true, false} { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1248,7 +1248,7 @@ func TestEndpointExpiration(t *testing.T) { func TestPromiscuousMode(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1300,7 +1300,7 @@ func TestSpoofingWithAddress(t *testing.T) { dstAddr := tcpip.Address("\x03") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1366,7 +1366,7 @@ func TestSpoofingNoAddress(t *testing.T) { dstAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1429,7 +1429,7 @@ func verifyRoute(gotRoute, wantRoute stack.Route) error { func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1472,7 +1472,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // Create a new stack with two NICs. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1573,7 +1573,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1630,8 +1630,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { func TestNetworkOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{}, }) opt := tcpip.DefaultTTLOption(5) @@ -1657,7 +1657,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1724,7 +1724,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1809,7 +1809,7 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -1836,7 +1836,7 @@ func TestAddAddress(t *testing.T) { func TestAddProtocolAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -1870,7 +1870,7 @@ func TestAddProtocolAddress(t *testing.T) { func TestAddAddressWithOptions(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -1901,7 +1901,7 @@ func TestAddAddressWithOptions(t *testing.T) { func TestAddProtocolAddressWithOptions(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -2022,7 +2022,7 @@ func TestCreateNICWithOptions(t *testing.T) { func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -2089,7 +2089,7 @@ func TestNICForwarding(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) s.SetForwarding(fakeNetNumber, true) @@ -2336,7 +2336,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, AutoGenIPv6LinkLocal: test.autoGen, NDPDisp: &ndpDisp, OpaqueIIDOpts: test.iidOpts, @@ -2430,7 +2430,7 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, AutoGenIPv6LinkLocal: true, OpaqueIIDOpts: test.opaqueIIDOpts, } @@ -2463,7 +2463,7 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { } ndpConfigs := stack.DefaultNDPConfigurations() opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: ndpConfigs, AutoGenIPv6LinkLocal: true, NDPDisp: &ndpDisp, @@ -2522,7 +2522,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { for _, ps := range pebs { t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -2813,8 +2813,8 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -2869,7 +2869,7 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { e := loopback.New() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) nicOpts := stack.NICOptions{Disabled: true} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { @@ -2921,7 +2921,7 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { @@ -2982,7 +2982,7 @@ func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := loopback.New() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, }) nicOpts := stack.NICOptions{Disabled: true} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { @@ -3059,7 +3059,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { dadC: make(chan ndpDADEvent), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, @@ -3423,7 +3423,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, }) ep := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, ep); err != nil { @@ -3461,7 +3461,7 @@ func TestResolveWith(t *testing.T) { ) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, }) ep := channel.New(0, defaultMTU, "") ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 4d6d62eec..698c8609e 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -51,8 +51,8 @@ type testContext struct { // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) linkEps := make(map[tcpip.NICID]*channel.Endpoint) for _, linkEpID := range linkEpIDs { @@ -182,8 +182,8 @@ func TestTransportDemuxerRegister(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index cbb34d224..8aae60740 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -326,15 +326,15 @@ func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { return ok } -func fakeTransFactory() stack.TransportProtocol { +func fakeTransFactory(*stack.Stack) stack.TransportProtocol { return &fakeTransportProtocol{} } func TestTransportReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -404,8 +404,8 @@ func TestTransportReceive(t *testing.T) { func TestTransportControlReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -481,8 +481,8 @@ func TestTransportControlReceive(t *testing.T) { func TestTransportSend(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -527,8 +527,8 @@ func TestTransportSend(t *testing.T) { func TestTransportOptions(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) v := tcpip.TCPModerateReceiveBufferOption(true) @@ -546,8 +546,8 @@ func TestTransportOptions(t *testing.T) { func TestTransportForwarding(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) s.SetForwarding(fakeNetNumber, true) diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index fecbe7ba7..f35dcc084 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -120,8 +120,8 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) @@ -203,7 +203,7 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { otherAddr := tcpip.Address(addrBytes) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 659acbc7a..72d86b5ab 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -140,11 +140,9 @@ func TestPingMulticastBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ipv4Proto := ipv4.NewProtocol() - ipv6Proto := ipv6.NewProtocol() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4Proto, ipv6Proto}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4(), icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, }) // We only expect a single packet in response to our ICMP Echo Request. e := channel.New(1, defaultMTU, "") @@ -176,18 +174,18 @@ func TestPingMulticastBroadcast(t *testing.T) { var rxICMP func(*channel.Endpoint, tcpip.Address) var expectedSrc tcpip.Address var expectedDst tcpip.Address - var proto stack.NetworkProtocol + var protoNum tcpip.NetworkProtocolNumber switch l := len(test.dstAddr); l { case header.IPv4AddressSize: rxICMP = rxIPv4ICMP expectedSrc = ipv4Addr.Address expectedDst = remoteIPv4Addr - proto = ipv4Proto + protoNum = header.IPv4ProtocolNumber case header.IPv6AddressSize: rxICMP = rxIPv6ICMP expectedSrc = ipv6Addr.Address expectedDst = remoteIPv6Addr - proto = ipv6Proto + protoNum = header.IPv6ProtocolNumber default: t.Fatalf("got unexpected address length = %d bytes", l) } @@ -205,7 +203,7 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) } - src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader().View()) + src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(pkt.Pkt.NetworkHeader().View()) if src != expectedSrc { t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) } @@ -380,8 +378,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID, e); err != nil { @@ -466,8 +464,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 941c3c08d..7484f4ad9 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -13,12 +13,7 @@ // limitations under the License. // Package icmp contains the implementation of the ICMP and IPv6-ICMP transport -// protocols for use in ping. To use it in the networking stack, this package -// must be added to the project, and activated on the stack by passing -// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport -// protocols when calling stack.New(). Then endpoints can be created by passing -// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number -// when calling Stack.NewEndpoint(). +// protocols for use in ping. package icmp import ( @@ -135,11 +130,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol4 returns an ICMPv4 transport protocol. -func NewProtocol4() stack.TransportProtocol { +func NewProtocol4(*stack.Stack) stack.TransportProtocol { return &protocol{ProtocolNumber4} } // NewProtocol6 returns an ICMPv6 transport protocol. -func NewProtocol6() stack.TransportProtocol { +func NewProtocol6(*stack.Stack) stack.TransportProtocol { return &protocol{ProtocolNumber6} } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 371067048..6a3c2c32b 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package tcp contains the implementation of the TCP transport protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing tcp.NewProtocol() as one of the -// transport protocols when calling stack.New(). Then endpoints can be created -// by passing tcp.ProtocolNumber as the transport protocol number when calling -// Stack.NewEndpoint(). +// Package tcp contains the implementation of the TCP transport protocol. package tcp import ( @@ -510,7 +505,7 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol returns a TCP transport protocol. -func NewProtocol() stack.TransportProtocol { +func NewProtocol(*stack.Stack) stack.TransportProtocol { p := protocol{ sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{ Min: MinBufferSize, diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 8326736dc..dd810f594 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4300,8 +4300,8 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func TestDefaultBufferSizes(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) // Check the default values. @@ -4363,8 +4363,8 @@ func TestDefaultBufferSizes(t *testing.T) { func TestMinMaxBufferSizes(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) // Check the default values. @@ -4420,8 +4420,8 @@ func TestMinMaxBufferSizes(t *testing.T) { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}}) ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -4470,11 +4470,11 @@ func TestBindToDeviceOption(t *testing.T) { func makeStack() (*stack.Stack, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), - ipv6.NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, }, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) id := loopback.New() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index ebbae6e2f..faf51ef95 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -155,8 +155,8 @@ type Context struct { // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) const sendBufferSize = 1 << 20 // 1 MiB diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index a1d0f49d9..e6fc23258 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package udp contains the implementation of the UDP transport protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing udp.NewProtocol() as one of the -// transport protocols when calling stack.New(). Then endpoints can be created -// by passing udp.ProtocolNumber as the transport protocol number when calling -// Stack.NewEndpoint(). +// Package udp contains the implementation of the UDP transport protocol. package udp import ( @@ -119,6 +114,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol returns a UDP transport protocol. -func NewProtocol() stack.TransportProtocol { +func NewProtocol(*stack.Stack) stack.TransportProtocol { return &protocol{} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 6813d57aa..0556ef879 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -294,8 +294,8 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) } @@ -532,8 +532,8 @@ func newMinPayload(minSize int) []byte { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}}) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -807,8 +807,8 @@ func TestV4ReadSelfSource(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, HandleLocal: tt.handleLocal, }) defer c.cleanup() @@ -1485,13 +1485,13 @@ func TestTTL(t *testing.T) { } else { var p stack.NetworkProtocol if flow.isV4() { - p = ipv4.NewProtocol() + p = ipv4.NewProtocol(nil) } else { - p = ipv6.NewProtocol() + p = ipv6.NewProtocol(nil) } ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, })) wantTTL = ep.DefaultTTL() ep.Close() @@ -2345,9 +2345,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, e); err != nil { diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 4940ea96a..2e652ddad 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -1059,8 +1059,8 @@ func newRootNetworkNamespace(conf *config.Config, clock tcpip.Clock, uniqueID st } func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) { - netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()} - transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()} + netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol} + transProtos := []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4} s := netstack.Stack{stack.New(stack.Options{ NetworkProtocols: netProtos, TransportProtocols: transProtos, diff --git a/test/benchmarks/tcp/tcp_proxy.go b/test/benchmarks/tcp/tcp_proxy.go index 6cabfb451..5afe10f69 100644 --- a/test/benchmarks/tcp/tcp_proxy.go +++ b/test/benchmarks/tcp/tcp_proxy.go @@ -174,8 +174,8 @@ func newNetstackImpl(mode string) (impl, error) { } // Create a new network stack. - netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()} - transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()} + netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol} + transProtos := []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol} s := stack.New(stack.Options{ NetworkProtocols: netProtos, TransportProtocols: transProtos, |