From a5acc0616c9552c7252e3f133f9ad4648422cc5f Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 28 Sep 2020 16:22:09 -0700 Subject: Support creating protocol instances with Stack ref Network or transport protocols may want to reach the stack. Support this by letting the stack create the protocol instances so it can pass a reference to itself at protocol creation time. Note, protocols do not yet use the stack in this CL but later CLs will make use of the stack from protocols. PiperOrigin-RevId: 334260210 --- pkg/tcpip/adapters/gonet/gonet_test.go | 4 +- pkg/tcpip/network/arp/arp.go | 14 ++-- pkg/tcpip/network/arp/arp_test.go | 6 +- pkg/tcpip/network/ip_test.go | 47 ++++++------ pkg/tcpip/network/ipv4/ipv4.go | 9 +-- pkg/tcpip/network/ipv4/ipv4_test.go | 14 ++-- pkg/tcpip/network/ipv6/icmp_test.go | 24 +++---- pkg/tcpip/network/ipv6/ipv6.go | 9 +-- pkg/tcpip/network/ipv6/ipv6_test.go | 36 +++++----- pkg/tcpip/network/ipv6/ndp_test.go | 16 ++--- pkg/tcpip/sample/tun_tcp_connect/main.go | 4 +- pkg/tcpip/sample/tun_tcp_echo/main.go | 4 +- pkg/tcpip/stack/forwarder_test.go | 2 +- pkg/tcpip/stack/ndp_test.go | 80 ++++++++++----------- pkg/tcpip/stack/nic_test.go | 6 +- pkg/tcpip/stack/nud_test.go | 20 +++--- pkg/tcpip/stack/stack.go | 22 ++++-- pkg/tcpip/stack/stack_test.go | 84 +++++++++++----------- pkg/tcpip/stack/transport_demuxer_test.go | 8 +-- pkg/tcpip/stack/transport_test.go | 22 +++--- pkg/tcpip/tests/integration/loopback_test.go | 6 +- .../tests/integration/multicast_broadcast_test.go | 22 +++--- pkg/tcpip/transport/icmp/protocol.go | 11 +-- pkg/tcpip/transport/tcp/protocol.go | 9 +-- pkg/tcpip/transport/tcp/tcp_test.go | 20 +++--- pkg/tcpip/transport/tcp/testing/context/context.go | 4 +- pkg/tcpip/transport/udp/protocol.go | 9 +-- pkg/tcpip/transport/udp/udp_test.go | 25 ++++--- runsc/boot/loader.go | 4 +- test/benchmarks/tcp/tcp_proxy.go | 4 +- 30 files changed, 269 insertions(+), 276 deletions(-) diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index c975ad9cf..12b061def 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -61,8 +61,8 @@ func TestTimeouts(t *testing.T) { func newLoopbackStack() (*stack.Stack, *tcpip.Error) { // Create the stack and add a NIC. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, }) if err := s.CreateNIC(NICID, loopback.New()); err != nil { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index cb9225bd7..b025bb087 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -15,14 +15,6 @@ // Package arp implements the ARP network protocol. It is used to resolve // IPv4 addresses into link-local MAC addresses, and advertises IPv4 // addresses of its stack with the local network. -// -// To use it in the networking stack, pass arp.NewProtocol() as one of the -// network protocols when calling stack.New. Then add an "arp" address to every -// NIC on the stack that should respond to ARP requests. That is: -// -// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { -// // handle err -// } package arp import ( @@ -239,6 +231,10 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } // NewProtocol returns an ARP network protocol. -func NewProtocol() stack.NetworkProtocol { +// +// Note, to make sure that the ARP endpoint receives ARP packets, the "arp" +// address must be added to every NIC that should respond to ARP requests. See +// ProtocolAddress for more details. +func NewProtocol(*stack.Stack) stack.NetworkProtocol { return &protocol{} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 9c9a859e3..626af975a 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -176,8 +176,8 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, NUDConfigs: c, NUDDisp: &d, UseNeighborCache: useNeighborCache, @@ -442,7 +442,7 @@ func TestLinkAddressRequest(t *testing.T) { } for _, test := range tests { - p := arp.NewProtocol() + p := arp.NewProtocol(nil) linkRes, ok := p.(stack.LinkAddressResolver) if !ok { t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 477635f97..4640ca95c 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -195,8 +195,8 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) s.AddAddress(nicID, ipv4.ProtocolNumber, local) @@ -211,8 +211,8 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) s.CreateNIC(nicID, loopback.New()) s.AddAddress(nicID, ipv6.ProtocolNumber, local) @@ -229,8 +229,8 @@ func buildDummyStack(t *testing.T) *stack.Stack { t.Helper() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) e := channel.New(0, 1280, "") if err := s.CreateNIC(nicID, e); err != nil { @@ -250,8 +250,9 @@ func buildDummyStack(t *testing.T) *stack.Stack { func TestIPv4Send(t *testing.T) { o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, s) defer ep.Close() // Allocate and initialize the payload view. @@ -287,8 +288,9 @@ func TestIPv4Send(t *testing.T) { func TestIPv4Receive(t *testing.T) { o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) defer ep.Close() totalLen := header.IPv4MinimumSize + 30 @@ -357,8 +359,9 @@ func TestIPv4ReceiveControl(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} - proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) defer ep.Close() const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize @@ -418,8 +421,9 @@ func TestIPv4ReceiveControl(t *testing.T) { func TestIPv4FragmentationReceive(t *testing.T) { o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) defer ep.Close() totalLen := header.IPv4MinimumSize + 24 @@ -495,8 +499,9 @@ func TestIPv4FragmentationReceive(t *testing.T) { func TestIPv6Send(t *testing.T) { o := testObject{t: t} - proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), s) defer ep.Close() // Allocate and initialize the payload view. @@ -532,8 +537,9 @@ func TestIPv6Send(t *testing.T) { func TestIPv6Receive(t *testing.T) { o := testObject{t: t} - proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) defer ep.Close() totalLen := header.IPv6MinimumSize + 30 @@ -611,8 +617,9 @@ func TestIPv6ReceiveControl(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} - proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) defer ep.Close() dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e589d923d..254d66147 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ipv4 contains the implementation of the ipv4 network protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv4.NewProtocol() as one of the network -// protocols when calling stack.New(). Then endpoints can be created by passing -// ipv4.ProtocolNumber as the network protocol number when calling -// Stack.NewEndpoint(). +// Package ipv4 contains the implementation of the ipv4 network protocol. package ipv4 import ( @@ -584,7 +579,7 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV ui } // NewProtocol returns an IPv4 network protocol. -func NewProtocol() stack.NetworkProtocol { +func NewProtocol(*stack.Stack) stack.NetworkProtocol { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 86187aba8..0b3ed9483 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -36,8 +36,8 @@ import ( func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) const defaultMTU = 65536 @@ -517,8 +517,8 @@ func TestInvalidFragments(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, }, }) e := channel.New(0, 1500, linkAddr) @@ -929,8 +929,8 @@ func TestReceiveFragments(t *testing.T) { t.Run(test.name, func(t *testing.T) { // Setup a stack and endpoint. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) if err := s.CreateNIC(nicID, e); err != nil { @@ -1140,7 +1140,7 @@ func TestWriteStats(t *testing.T) { func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatalf("CreateNIC(1, _) failed: %s", err) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index f82e34209..dd58022d6 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -121,8 +121,8 @@ func TestICMPCounts(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: test.useNeighborCache, }) { @@ -259,8 +259,8 @@ func TestICMPCounts(t *testing.T) { func TestICMPCountsWithNeighborCache(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: true, }) { @@ -424,12 +424,12 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab func newTestContext(t *testing.T) *testContext { c := &testContext{ s0: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }), s1: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }), } @@ -724,7 +724,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: test.useNeighborCache, }) if isRouter { @@ -920,7 +920,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { t.Run(typ.name, func(t *testing.T) { e := channel.New(10, 1280, linkAddr0) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) @@ -1098,7 +1098,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { t.Run(typ.name, func(t *testing.T) { e := channel.New(10, 1280, linkAddr0) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -1204,7 +1204,7 @@ func TestLinkAddressRequest(t *testing.T) { } for _, test := range tests { - p := NewProtocol() + p := NewProtocol(nil) linkRes, ok := p.(stack.LinkAddressResolver) if !ok { t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 3d070ddea..e436c6a9e 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ipv6 contains the implementation of the ipv6 network protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv6.NewProtocol() as one of the network -// protocols when calling stack.New(). Then endpoints can be created by passing -// ipv6.ProtocolNumber as the network protocol number when calling -// Stack.NewEndpoint(). +// Package ipv6 contains the implementation of the ipv6 network protocol. package ipv6 import ( @@ -617,7 +612,7 @@ func calculateMTU(mtu uint32) uint32 { } // NewProtocol returns an IPv6 network protocol. -func NewProtocol() stack.NetworkProtocol { +func NewProtocol(*stack.Stack) stack.NetworkProtocol { return &protocol{ defaultTTL: DefaultTTL, fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 7d138dadb..8ae146c5e 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -141,18 +141,18 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { tests := []struct { name string - protocolFactory stack.TransportProtocol + protocolFactory stack.TransportProtocolFactory rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, + {"ICMP", icmp.NewProtocol6, testReceiveICMP}, + {"UDP", udp.NewProtocol, testReceiveUDP}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { @@ -174,11 +174,11 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { tests := []struct { name string - protocolFactory stack.TransportProtocol + protocolFactory stack.TransportProtocolFactory rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, + {"ICMP", icmp.NewProtocol6, testReceiveICMP}, + {"UDP", udp.NewProtocol, testReceiveUDP}, } snmc := header.SolicitedNodeAddr(addr2) @@ -186,8 +186,8 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) e := channel.New(1, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { @@ -273,7 +273,7 @@ func TestAddIpv6Address(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -579,8 +579,8 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { @@ -1549,8 +1549,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { @@ -1668,8 +1668,8 @@ func TestInvalidIPv6Fragments(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + NewProtocol, }, }) e := channel.New(0, 1500, linkAddr1) @@ -1847,7 +1847,7 @@ func TestWriteStats(t *testing.T) { func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatalf("CreateNIC(1, _) failed: %s", err) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 7434df4a1..c93d1194f 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -35,8 +35,8 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig t.Helper() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: useNeighborCache, }) @@ -98,7 +98,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) if err := s.CreateNIC(nicID, e); err != nil { @@ -202,7 +202,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: true, }) e := channel.New(0, 1280, linkAddr0) @@ -475,7 +475,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: stackTyp.useNeighborCache, }) e := channel.New(1, 1280, nicLinkAddr) @@ -596,7 +596,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) if err := s.CreateNIC(nicID, e); err != nil { @@ -707,7 +707,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: true, }) e := channel.New(0, 1280, linkAddr0) @@ -1172,7 +1172,7 @@ func TestRouterAdvertValidation(t *testing.T) { e := channel.New(10, 1280, linkAddr1) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, UseNeighborCache: stackTyp.useNeighborCache, }) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 91fc26722..51d428049 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -127,8 +127,8 @@ func main() { // Create the stack with ipv4 and tcp protocols, then add a tun-based // NIC and ipv4 address. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) mtu, err := rawfile.GetMTU(tunName) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 3f58a15ea..8e0ee1cd7 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -112,8 +112,8 @@ func main() { // Create the stack with ip and tcp protocols, then add a tun-based // NIC and address. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) mtu, err := rawfile.GetMTU(tunName) diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 38c5bac71..8d18f3c8c 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -307,7 +307,7 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ - NetworkProtocols: []NetworkProtocol{proto}, + NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, UseNeighborCache: useNeighborCache, }) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 5e43a9b0b..8416dbcdb 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -320,7 +320,7 @@ func TestDADDisabled(t *testing.T) { dadC: make(chan ndpDADEvent, 1), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPDisp: &ndpDisp, } @@ -414,7 +414,7 @@ func TestDADResolve(t *testing.T) { dadC: make(chan ndpDADEvent), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPDisp: &ndpDisp, } opts.NDPConfigs.RetransmitTimer = test.retransTimer @@ -638,7 +638,7 @@ func TestDADFail(t *testing.T) { } ndpConfigs := stack.DefaultNDPConfigurations() opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, } @@ -759,7 +759,7 @@ func TestDADStop(t *testing.T) { DupAddrDetectTransmits: 2, } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPDisp: &ndpDisp, NDPConfigs: ndpConfigs, } @@ -819,7 +819,7 @@ func TestDADStop(t *testing.T) { // we attempt to update NDP configurations using an invalid NICID. func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, }) // No NIC with ID 1 yet. @@ -863,7 +863,7 @@ func TestSetNDPConfigurations(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPDisp: &ndpDisp, }) @@ -1113,7 +1113,7 @@ func TestNoRouterDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: handle, DiscoverDefaultRouters: discover, @@ -1151,7 +1151,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: true, @@ -1192,7 +1192,7 @@ func TestRouterDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: true, @@ -1293,7 +1293,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: true, @@ -1358,7 +1358,7 @@ func TestNoPrefixDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: handle, DiscoverOnLinkPrefixes: discover, @@ -1399,7 +1399,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: false, @@ -1445,7 +1445,7 @@ func TestPrefixDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverOnLinkPrefixes: true, @@ -1545,7 +1545,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverOnLinkPrefixes: true, @@ -1629,7 +1629,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: false, @@ -1716,7 +1716,7 @@ func TestNoAutoGenAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: handle, AutoGenGlobalAddresses: autogen, @@ -1766,7 +1766,7 @@ func TestAutoGenAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -1931,7 +1931,7 @@ func TestAutoGenTempAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ DupAddrDetectTransmits: test.dupAddrTransmits, RetransmitTimer: test.retransmitTimer, @@ -2160,7 +2160,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ AutoGenTempGlobalAddresses: true, }, @@ -2228,7 +2228,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, @@ -2324,7 +2324,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { RegenAdvanceDuration: newMinVLDuration - regenAfter, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, }) @@ -2469,7 +2469,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { RegenAdvanceDuration: newMinVLDuration - regenAfter, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, }) @@ -2662,8 +2662,8 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { AutoGenAddressConflictRetries: 1, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ @@ -2794,8 +2794,8 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -3307,7 +3307,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -3449,7 +3449,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { } e := channel.New(10, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -3515,7 +3515,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -3700,7 +3700,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -3781,7 +3781,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -4029,7 +4029,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { ndpConfigs := addrType.ndpConfigs ndpConfigs.AutoGenAddressConflictRetries = maxRetries s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, @@ -4165,7 +4165,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, NDPConfigs: addrType.ndpConfigs, NDPDisp: &ndpDisp, @@ -4250,7 +4250,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, @@ -4459,7 +4459,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, }, @@ -4509,7 +4509,7 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, }, @@ -4694,7 +4694,7 @@ func TestCleanupNDPState(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, AutoGenIPv6LinkLocal: true, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, @@ -4967,7 +4967,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, }, @@ -5217,7 +5217,7 @@ func TestRouterSolicitation(t *testing.T) { } } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ MaxRtrSolicitations: test.maxRtrSolicit, RtrSolicitationInterval: test.rtrSolicitInt, @@ -5357,7 +5357,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { checker.NDPRS()) } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ MaxRtrSolicitations: maxRtrSolicitations, RtrSolicitationInterval: interval, diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index dd6474297..bc9c9881a 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -241,6 +241,10 @@ func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAdd return "", false } +func newTestIPv6Protocol(*Stack) NetworkProtocol { + return &testIPv6Protocol{} +} + // Test the race condition where a NIC is removed and an RS timer fires at the // same time. func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { @@ -252,7 +256,7 @@ func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { e := testLinkEndpoint{} s := New(Options{ - NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}}, + NetworkProtocols: []NetworkProtocolFactory{newTestIPv6Protocol}, NDPConfigs: NDPConfigurations{ MaxRtrSolicitations: maxRtrSolicitations, RtrSolicitationInterval: minimumRtrSolicitationInterval, diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 2b97e5972..8cffb9fc6 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -60,7 +60,7 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The networking // stack will only allocate neighbor caches if a protocol providing link // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, UseNeighborCache: true, }) @@ -137,7 +137,7 @@ func TestDefaultNUDConfigurations(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The networking // stack will only allocate neighbor caches if a protocol providing link // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: stack.DefaultNUDConfigurations(), UseNeighborCache: true, }) @@ -192,7 +192,7 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -249,7 +249,7 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -329,7 +329,7 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -391,7 +391,7 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -443,7 +443,7 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -495,7 +495,7 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -547,7 +547,7 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) @@ -599,7 +599,7 @@ func TestNUDConfigurationsUnreachableTime(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, UseNeighborCache: true, }) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e7b7e95d4..c22633f6b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -517,13 +517,25 @@ type UniqueID interface { UniqueID() uint64 } +// NetworkProtocolFactory instantiates a network protocol. +// +// NetworkProtocolFactory must not attempt to modify the stack, it may only +// query the stack. +type NetworkProtocolFactory func(*Stack) NetworkProtocol + +// TransportProtocolFactory instantiates a transport protocol. +// +// TransportProtocolFactory must not attempt to modify the stack, it may only +// query the stack. +type TransportProtocolFactory func(*Stack) TransportProtocol + // Options contains optional Stack configuration. type Options struct { // NetworkProtocols lists the network protocols to enable. - NetworkProtocols []NetworkProtocol + NetworkProtocols []NetworkProtocolFactory // TransportProtocols lists the transport protocols to enable. - TransportProtocols []TransportProtocol + TransportProtocols []TransportProtocolFactory // Clock is an optional clock source used for timestampping packets. // @@ -755,7 +767,8 @@ func New(opts Options) *Stack { s.forwarding.protocols = make(map[tcpip.NetworkProtocolNumber]bool) // Add specified network protocols. - for _, netProto := range opts.NetworkProtocols { + for _, netProtoFactory := range opts.NetworkProtocols { + netProto := netProtoFactory(s) s.networkProtocols[netProto.Number()] = netProto if r, ok := netProto.(LinkAddressResolver); ok { s.linkAddrResolvers[r.LinkAddressProtocol()] = r @@ -763,7 +776,8 @@ func New(opts Options) *Stack { } // Add specified transport protocols. - for _, transProto := range opts.TransportProtocols { + for _, transProtoFactory := range opts.TransportProtocols { + transProto := transProtoFactory(s) s.transportProtocols[transProto.Number()] = &transportProtocolState{ proto: transProto, } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index f92850053..c205650fe 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -231,7 +231,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } -func fakeNetFactory() stack.NetworkProtocol { +func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { return &fakeNetworkProtocol{} } @@ -268,7 +268,7 @@ func TestNetworkReceive(t *testing.T) { // addresses attached to it: 1 & 2. ep := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -428,7 +428,7 @@ func TestNetworkSend(t *testing.T) { // existing nic. ep := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) @@ -455,7 +455,7 @@ func TestNetworkSendMultiRoute(t *testing.T) { // addresses per nic, the first nic has odd address, the second one has // even addresses. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") @@ -555,7 +555,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := linkEPWithMockedAttach{ @@ -574,7 +574,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { func TestDisableUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { @@ -586,7 +586,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := loopback.New() @@ -633,7 +633,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { func TestRemoveUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { @@ -645,7 +645,7 @@ func TestRemoveNIC(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := linkEPWithMockedAttach{ @@ -706,7 +706,7 @@ func TestRouteWithDownNIC(t *testing.T) { setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(1, defaultMTU, "") @@ -872,7 +872,7 @@ func TestRoutes(t *testing.T) { // addresses per nic, the first nic has odd address, the second one has // even addresses. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") @@ -952,7 +952,7 @@ func TestAddressRemoval(t *testing.T) { remoteAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -999,7 +999,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { remoteAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1090,7 +1090,7 @@ func TestEndpointExpiration(t *testing.T) { for _, spoofing := range []bool{true, false} { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1248,7 +1248,7 @@ func TestEndpointExpiration(t *testing.T) { func TestPromiscuousMode(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1300,7 +1300,7 @@ func TestSpoofingWithAddress(t *testing.T) { dstAddr := tcpip.Address("\x03") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1366,7 +1366,7 @@ func TestSpoofingNoAddress(t *testing.T) { dstAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1429,7 +1429,7 @@ func verifyRoute(gotRoute, wantRoute stack.Route) error { func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1472,7 +1472,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // Create a new stack with two NICs. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1573,7 +1573,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1630,8 +1630,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { func TestNetworkOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{}, }) opt := tcpip.DefaultTTLOption(5) @@ -1657,7 +1657,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1724,7 +1724,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1809,7 +1809,7 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -1836,7 +1836,7 @@ func TestAddAddress(t *testing.T) { func TestAddProtocolAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -1870,7 +1870,7 @@ func TestAddProtocolAddress(t *testing.T) { func TestAddAddressWithOptions(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -1901,7 +1901,7 @@ func TestAddAddressWithOptions(t *testing.T) { func TestAddProtocolAddressWithOptions(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -2022,7 +2022,7 @@ func TestCreateNICWithOptions(t *testing.T) { func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -2089,7 +2089,7 @@ func TestNICForwarding(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) s.SetForwarding(fakeNetNumber, true) @@ -2336,7 +2336,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, AutoGenIPv6LinkLocal: test.autoGen, NDPDisp: &ndpDisp, OpaqueIIDOpts: test.iidOpts, @@ -2430,7 +2430,7 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, AutoGenIPv6LinkLocal: true, OpaqueIIDOpts: test.opaqueIIDOpts, } @@ -2463,7 +2463,7 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { } ndpConfigs := stack.DefaultNDPConfigurations() opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: ndpConfigs, AutoGenIPv6LinkLocal: true, NDPDisp: &ndpDisp, @@ -2522,7 +2522,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { for _, ps := range pebs { t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -2813,8 +2813,8 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, @@ -2869,7 +2869,7 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { e := loopback.New() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) nicOpts := stack.NICOptions{Disabled: true} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { @@ -2921,7 +2921,7 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { @@ -2982,7 +2982,7 @@ func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := loopback.New() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, }) nicOpts := stack.NICOptions{Disabled: true} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { @@ -3059,7 +3059,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { dadC: make(chan ndpDADEvent), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NDPConfigs: stack.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, @@ -3423,7 +3423,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, }) ep := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, ep); err != nil { @@ -3461,7 +3461,7 @@ func TestResolveWith(t *testing.T) { ) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, }) ep := channel.New(0, defaultMTU, "") ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 4d6d62eec..698c8609e 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -51,8 +51,8 @@ type testContext struct { // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) linkEps := make(map[tcpip.NICID]*channel.Endpoint) for _, linkEpID := range linkEpIDs { @@ -182,8 +182,8 @@ func TestTransportDemuxerRegister(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index cbb34d224..8aae60740 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -326,15 +326,15 @@ func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { return ok } -func fakeTransFactory() stack.TransportProtocol { +func fakeTransFactory(*stack.Stack) stack.TransportProtocol { return &fakeTransportProtocol{} } func TestTransportReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -404,8 +404,8 @@ func TestTransportReceive(t *testing.T) { func TestTransportControlReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -481,8 +481,8 @@ func TestTransportControlReceive(t *testing.T) { func TestTransportSend(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -527,8 +527,8 @@ func TestTransportSend(t *testing.T) { func TestTransportOptions(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) v := tcpip.TCPModerateReceiveBufferOption(true) @@ -546,8 +546,8 @@ func TestTransportOptions(t *testing.T) { func TestTransportForwarding(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) s.SetForwarding(fakeNetNumber, true) diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index fecbe7ba7..f35dcc084 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -120,8 +120,8 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) @@ -203,7 +203,7 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { otherAddr := tcpip.Address(addrBytes) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 659acbc7a..72d86b5ab 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -140,11 +140,9 @@ func TestPingMulticastBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ipv4Proto := ipv4.NewProtocol() - ipv6Proto := ipv6.NewProtocol() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4Proto, ipv6Proto}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4(), icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, }) // We only expect a single packet in response to our ICMP Echo Request. e := channel.New(1, defaultMTU, "") @@ -176,18 +174,18 @@ func TestPingMulticastBroadcast(t *testing.T) { var rxICMP func(*channel.Endpoint, tcpip.Address) var expectedSrc tcpip.Address var expectedDst tcpip.Address - var proto stack.NetworkProtocol + var protoNum tcpip.NetworkProtocolNumber switch l := len(test.dstAddr); l { case header.IPv4AddressSize: rxICMP = rxIPv4ICMP expectedSrc = ipv4Addr.Address expectedDst = remoteIPv4Addr - proto = ipv4Proto + protoNum = header.IPv4ProtocolNumber case header.IPv6AddressSize: rxICMP = rxIPv6ICMP expectedSrc = ipv6Addr.Address expectedDst = remoteIPv6Addr - proto = ipv6Proto + protoNum = header.IPv6ProtocolNumber default: t.Fatalf("got unexpected address length = %d bytes", l) } @@ -205,7 +203,7 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) } - src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader().View()) + src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(pkt.Pkt.NetworkHeader().View()) if src != expectedSrc { t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) } @@ -380,8 +378,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID, e); err != nil { @@ -466,8 +464,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) if err := s.CreateNIC(nicID, loopback.New()); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 941c3c08d..7484f4ad9 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -13,12 +13,7 @@ // limitations under the License. // Package icmp contains the implementation of the ICMP and IPv6-ICMP transport -// protocols for use in ping. To use it in the networking stack, this package -// must be added to the project, and activated on the stack by passing -// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport -// protocols when calling stack.New(). Then endpoints can be created by passing -// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number -// when calling Stack.NewEndpoint(). +// protocols for use in ping. package icmp import ( @@ -135,11 +130,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol4 returns an ICMPv4 transport protocol. -func NewProtocol4() stack.TransportProtocol { +func NewProtocol4(*stack.Stack) stack.TransportProtocol { return &protocol{ProtocolNumber4} } // NewProtocol6 returns an ICMPv6 transport protocol. -func NewProtocol6() stack.TransportProtocol { +func NewProtocol6(*stack.Stack) stack.TransportProtocol { return &protocol{ProtocolNumber6} } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 371067048..6a3c2c32b 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package tcp contains the implementation of the TCP transport protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing tcp.NewProtocol() as one of the -// transport protocols when calling stack.New(). Then endpoints can be created -// by passing tcp.ProtocolNumber as the transport protocol number when calling -// Stack.NewEndpoint(). +// Package tcp contains the implementation of the TCP transport protocol. package tcp import ( @@ -510,7 +505,7 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol returns a TCP transport protocol. -func NewProtocol() stack.TransportProtocol { +func NewProtocol(*stack.Stack) stack.TransportProtocol { p := protocol{ sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{ Min: MinBufferSize, diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 8326736dc..dd810f594 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4300,8 +4300,8 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func TestDefaultBufferSizes(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) // Check the default values. @@ -4363,8 +4363,8 @@ func TestDefaultBufferSizes(t *testing.T) { func TestMinMaxBufferSizes(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) // Check the default values. @@ -4420,8 +4420,8 @@ func TestMinMaxBufferSizes(t *testing.T) { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}}) ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -4470,11 +4470,11 @@ func TestBindToDeviceOption(t *testing.T) { func makeStack() (*stack.Stack, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), - ipv6.NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, }, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) id := loopback.New() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index ebbae6e2f..faf51ef95 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -155,8 +155,8 @@ type Context struct { // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) const sendBufferSize = 1 << 20 // 1 MiB diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index a1d0f49d9..e6fc23258 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package udp contains the implementation of the UDP transport protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing udp.NewProtocol() as one of the -// transport protocols when calling stack.New(). Then endpoints can be created -// by passing udp.ProtocolNumber as the transport protocol number when calling -// Stack.NewEndpoint(). +// Package udp contains the implementation of the UDP transport protocol. package udp import ( @@ -119,6 +114,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol returns a UDP transport protocol. -func NewProtocol() stack.TransportProtocol { +func NewProtocol(*stack.Stack) stack.TransportProtocol { return &protocol{} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 6813d57aa..0556ef879 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -294,8 +294,8 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) } @@ -532,8 +532,8 @@ func newMinPayload(minSize int) []byte { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}}) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -807,8 +807,8 @@ func TestV4ReadSelfSource(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, HandleLocal: tt.handleLocal, }) defer c.cleanup() @@ -1485,13 +1485,13 @@ func TestTTL(t *testing.T) { } else { var p stack.NetworkProtocol if flow.isV4() { - p = ipv4.NewProtocol() + p = ipv4.NewProtocol(nil) } else { - p = ipv6.NewProtocol() + p = ipv6.NewProtocol(nil) } ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, })) wantTTL = ep.DefaultTTL() ep.Close() @@ -2345,9 +2345,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, e); err != nil { diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 4940ea96a..2e652ddad 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -1059,8 +1059,8 @@ func newRootNetworkNamespace(conf *config.Config, clock tcpip.Clock, uniqueID st } func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) { - netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()} - transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()} + netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol} + transProtos := []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4} s := netstack.Stack{stack.New(stack.Options{ NetworkProtocols: netProtos, TransportProtocols: transProtos, diff --git a/test/benchmarks/tcp/tcp_proxy.go b/test/benchmarks/tcp/tcp_proxy.go index 6cabfb451..5afe10f69 100644 --- a/test/benchmarks/tcp/tcp_proxy.go +++ b/test/benchmarks/tcp/tcp_proxy.go @@ -174,8 +174,8 @@ func newNetstackImpl(mode string) (impl, error) { } // Create a new network stack. - netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()} - transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()} + netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol} + transProtos := []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol} s := stack.New(stack.Options{ NetworkProtocols: netProtos, TransportProtocols: transProtos, -- cgit v1.2.3