diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2020-09-29 02:04:11 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-09-29 02:05:50 -0700 |
commit | 5075d0342f51b3e44ae47fc0901a59a4d762c638 (patch) | |
tree | 0aea148dedd6654f16bb83bc9fe3bf12e1795333 /pkg/tcpip | |
parent | 48915bdedb346432327986570f28181d48b68567 (diff) |
Trim Network/Transport Endpoint/Protocol
* Remove Capabilities and NICID methods from NetworkEndpoint.
* Remove linkEP and stack parameters from NetworkProtocol.NewEndpoint.
The LinkEndpoint can be fetched from the NetworkInterface. The stack
is passed to the NetworkProtocol when it is created so the
NetworkEndpoint can get it from its protocol.
* Remove stack parameter from TransportProtocol.NewEndpoint.
Like the NetworkProtocol/Endpoint, the stack is passed to the
TransportProtocol when it is created.
PiperOrigin-RevId: 334332721
Diffstat (limited to 'pkg/tcpip')
23 files changed, 244 insertions, 248 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 4bb7a417c..b47a7be51 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -96,14 +96,6 @@ func (e *endpoint) MTU() uint32 { return lmtu - uint32(e.MaxHeaderLength()) } -func (e *endpoint) NICID() tcpip.NICID { - return e.nic.ID() -} - -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - func (e *endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.ARPSize } @@ -145,15 +137,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { localAddr := tcpip.Address(h.ProtocolAddressTarget()) if e.nud == nil { - if e.linkAddrCache.CheckLocalAddress(e.NICID(), header.IPv4ProtocolNumber, localAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.NICID(), addr, linkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) } else { - if r.Stack().CheckLocalAddress(e.NICID(), header.IPv4ProtocolNumber, localAddr) == 0 { + if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } @@ -179,7 +171,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) if e.nud == nil { - e.linkAddrCache.AddLinkAddress(e.NICID(), addr, linkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) return } @@ -211,11 +203,11 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ protocol: p, nic: nic, - linkEP: sender, + linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 66450f896..56a56362e 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -252,6 +252,8 @@ func buildDummyStack(t *testing.T) *stack.Stack { var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { + tester testObject + mu struct { sync.RWMutex disabled bool @@ -282,6 +284,10 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } +func (t *testInterface) LinkEndpoint() stack.LinkEndpoint { + return &t.tester +} + func TestEnableWhenNICDisabled(t *testing.T) { tests := []struct { name string @@ -312,7 +318,7 @@ func TestEnableWhenNICDisabled(t *testing.T) { // We pass nil for all parameters except the NetworkInterface and Stack // since Enable only depends on these. - ep := p.NewEndpoint(&nic, nil, nil, nil, nil, s) + ep := p.NewEndpoint(&nic, nil, nil, nil) // The endpoint should initially be disabled, regardless the NIC's enabled // status. @@ -365,10 +371,15 @@ func TestEnableWhenNICDisabled(t *testing.T) { } func TestIPv4Send(t *testing.T) { - o := testObject{t: t, v4: true} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil, &o, s) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, nil) defer ep.Close() // Allocate and initialize the payload view. @@ -384,10 +395,10 @@ func TestIPv4Send(t *testing.T) { }) // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv4Addr - o.dstAddr = remoteIpv4Addr - o.contents = payload + nic.tester.protocol = 123 + nic.tester.srcAddr = localIpv4Addr + nic.tester.dstAddr = remoteIpv4Addr + nic.tester.contents = payload r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) if err != nil { @@ -403,10 +414,15 @@ func TestIPv4Send(t *testing.T) { } func TestIPv4Receive(t *testing.T) { - o := testObject{t: t, v4: true} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() if err := ep.Enable(); err != nil { @@ -431,10 +447,10 @@ func TestIPv4Receive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[header.IPv4MinimumSize:totalLen] + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIpv4Addr + nic.tester.dstAddr = localIpv4Addr + nic.tester.contents = view[header.IPv4MinimumSize:totalLen] r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) if err != nil { @@ -447,8 +463,8 @@ func TestIPv4Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } @@ -478,10 +494,14 @@ func TestIPv4ReceiveControl(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() if err := ep.Enable(); err != nil { @@ -528,26 +548,31 @@ func TestIPv4ReceiveControl(t *testing.T) { // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIpv4Addr + nic.tester.dstAddr = localIpv4Addr + nic.tester.contents = view[dataOffset:] + nic.tester.typ = c.expectedTyp + nic.tester.extra = c.expectedExtra ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + if want := c.expectedCount; nic.tester.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) } }) } } func TestIPv4FragmentationReceive(t *testing.T) { - o := testObject{t: t, v4: true} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() if err := ep.Enable(); err != nil { @@ -590,10 +615,10 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIpv4Addr + nic.tester.dstAddr = localIpv4Addr + nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) if err != nil { @@ -608,8 +633,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) + if nic.tester.dataCalls != 0 { + t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls) } // Send second segment. @@ -620,16 +645,20 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } func TestIPv6Send(t *testing.T) { - o := testObject{t: t} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, channel.New(0, 1280, ""), s) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, nil) defer ep.Close() if err := ep.Enable(); err != nil { @@ -649,10 +678,10 @@ func TestIPv6Send(t *testing.T) { }) // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv6Addr - o.dstAddr = remoteIpv6Addr - o.contents = payload + nic.tester.protocol = 123 + nic.tester.srcAddr = localIpv6Addr + nic.tester.dstAddr = remoteIpv6Addr + nic.tester.contents = payload r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) if err != nil { @@ -668,10 +697,14 @@ func TestIPv6Send(t *testing.T) { } func TestIPv6Receive(t *testing.T) { - o := testObject{t: t} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() if err := ep.Enable(); err != nil { @@ -695,10 +728,10 @@ func TestIPv6Receive(t *testing.T) { } // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[header.IPv6MinimumSize:totalLen] + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIpv6Addr + nic.tester.dstAddr = localIpv6Addr + nic.tester.contents = view[header.IPv6MinimumSize:totalLen] r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) if err != nil { @@ -712,8 +745,8 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } @@ -752,10 +785,14 @@ func TestIPv6ReceiveControl(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() if err := ep.Enable(); err != nil { @@ -814,19 +851,19 @@ func TestIPv6ReceiveControl(t *testing.T) { // Give packet to IPv6 endpoint, dispatcher will validate that // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIpv6Addr + nic.tester.dstAddr = localIpv6Addr + nic.tester.contents = view[dataOffset:] + nic.tester.typ = c.expectedTyp + nic.tester.extra = c.expectedExtra // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + if want := c.expectedCount; nic.tester.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) } }) } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 3e5cf2ad9..5c4f715d7 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -40,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match an address we own. src := hdr.SourceAddress() - if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { return } @@ -110,7 +110,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { localAddr = "" } - r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 41f6914b9..7adf0fac3 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -59,7 +59,6 @@ type endpoint struct { linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher protocol *protocol - stack *stack.Stack // enabled is set to 1 when the enpoint is enabled and 0 when it is // disabled. @@ -75,13 +74,12 @@ type endpoint struct { } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, - linkEP: linkEP, + linkEP: nic.LinkEndpoint(), dispatcher: dispatcher, protocol: p, - stack: st, } e.mu.addressableEndpointState.Init(e) return e @@ -173,16 +171,6 @@ func (e *endpoint) MTU() uint32 { return calculateMTU(e.linkEP.MTU()) } -// Capabilities implements stack.NetworkEndpoint. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - -// NICID returns the ID of the NIC this endpoint belongs to. -func (e *endpoint) NICID() tcpip.NICID { - return e.nic.ID() -} - // MaxHeaderLength returns the maximum length needed by ipv4 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { @@ -324,8 +312,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. - nicName := e.stack.FindNICNameFromID(e.NICID()) - ipt := e.stack.IPTables() + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. r.Stats().IP.IPTablesOutputDropped.Increment() @@ -341,7 +329,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) - ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) ep.HandlePacket(&route, pkt) @@ -381,10 +369,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pkt = pkt.Next() } - nicName := e.stack.FindNICNameFromID(e.NICID()) + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - ipt := e.stack.IPTables() + ipt := e.protocol.stack.IPTables() dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the @@ -404,7 +392,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) - if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) @@ -493,7 +481,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - ipt := e.stack.IPTables() + ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. r.Stats().IP.IPTablesInputDropped.Increment() @@ -677,6 +665,8 @@ var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) type protocol struct { + stack *stack.Stack + // defaultTTL is the current default TTL for the protocol. Only the // uint8 portion of it is meaningful. // @@ -799,7 +789,7 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV ui } // NewProtocol returns an IPv4 network protocol. -func NewProtocol(*stack.Stack) stack.NetworkProtocol { +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. @@ -810,6 +800,7 @@ func NewProtocol(*stack.Stack) stack.NetworkProtocol { hashIV := r[buckets] return &protocol{ + stack: s, ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL, diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 270439b5c..4b4b483cc 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -41,7 +41,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match an address we own. src := hdr.SourceAddress() - if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { return } @@ -248,7 +248,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // section 5.4.3. // Is the NS targeting us? - if r.Stack().CheckLocalAddress(e.NICID(), ProtocolNumber, targetAddr) == 0 { + if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { return } @@ -283,7 +283,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } else if e.nud != nil { e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.NICID(), r.RemoteAddress, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) } // ICMPv6 Neighbor Solicit messages are always sent to @@ -410,7 +410,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // address cache with the link address for the target of the message. if len(targetLinkAddr) != 0 { if e.nud == nil { - e.linkAddrCache.AddLinkAddress(e.NICID(), targetAddr, targetLinkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr) return } @@ -438,7 +438,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme localAddr = "" } - r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index b4e8a077f..5472ceb46 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -123,6 +123,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) LinkEndpoint() stack.LinkEndpoint { + return nil +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -170,7 +174,7 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { @@ -312,7 +316,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 75b27a4cf..d1ad7acb7 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -351,16 +351,6 @@ func (e *endpoint) MTU() uint32 { return calculateMTU(e.linkEP.MTU()) } -// NICID returns the ID of the NIC this endpoint belongs to. -func (e *endpoint) NICID() tcpip.NICID { - return e.nic.ID() -} - -// Capabilities implements stack.NetworkEndpoint. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { @@ -395,8 +385,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. - nicName := e.stack.FindNICNameFromID(e.NICID()) - ipt := e.stack.IPTables() + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. r.Stats().IP.IPTablesOutputDropped.Increment() @@ -412,7 +402,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) ep.HandlePacket(&route, pkt) return nil @@ -455,8 +445,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. - nicName := e.stack.FindNICNameFromID(e.NICID()) - ipt := e.stack.IPTables() + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the @@ -476,7 +466,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) @@ -531,7 +521,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - ipt := e.stack.IPTables() + ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. r.Stats().IP.IPTablesInputDropped.Increment() @@ -1084,6 +1074,8 @@ var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) type protocol struct { + stack *stack.Stack + mu struct { sync.RWMutex @@ -1147,15 +1139,14 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, - linkEP: linkEP, + linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, dispatcher: dispatcher, protocol: p, - stack: st, } e.mu.addressableEndpointState.Init(e) e.mu.ndp = ndpState{ @@ -1312,8 +1303,9 @@ type Options struct { func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { opts.NDPConfigs.validate() - return func(*stack.Stack) stack.NetworkProtocol { + return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ + stack: s, fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), ndpDisp: opts.NDPDisp, diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 3495a8b19..d85b5c00f 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -1901,7 +1901,7 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint) { proto.mu.Lock() _, hasEP := proto.mu.eps[ep] diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 1b5c61b80..84c082852 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -627,7 +627,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE if addressEndpoint.GetKind() != stack.PermanentTentative { // The endpoint should be marked as tentative since we are starting DAD. - panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.NICID())) + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID())) } // Should not attempt to perform DAD on an address that is currently in the @@ -639,7 +639,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // address, or its reference count would have been increased without doing // the work that would have been done for an address that was brand new. // See endpoint.addAddressLocked. - panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.NICID())) + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID())) } remaining := ndp.configs.DupAddrDetectTransmits @@ -649,7 +649,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // Consider DAD to have resolved even if no DAD messages were actually // transmitted. if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.NICID(), addr, true, nil) + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) } return nil @@ -661,7 +661,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // cannot be done while holding the IPv6 endpoint's lock. This is effectively // the same as starting a goroutine but we use a timer that fires immediately // so we can reset it for the next DAD iteration. - timer = ndp.ep.stack.Clock().AfterFunc(0, func() { + timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() { ndp.ep.mu.Lock() defer ndp.ep.mu.Unlock() @@ -676,7 +676,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE if addressEndpoint.GetKind() != stack.PermanentTentative { // The endpoint should still be marked as tentative since we are still // performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.NICID())) + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) } dadDone := remaining == 0 @@ -721,7 +721,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE delete(ndp.dad, addr) if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.NICID(), addr, dadDone, err) + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) } // If DAD resolved for a stable SLAAC address, attempt generation of a @@ -750,7 +750,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - r, err := ndp.ep.stack.FindRoute(ndp.ep.NICID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */) + r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } @@ -766,9 +766,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add return err } - panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.NICID(), err)) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err)) } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.NICID())) + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID())) } icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) @@ -824,7 +824,7 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { // Let the integrator know DAD did not resolve. if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.NICID(), addr, false, nil) + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil) } } @@ -861,7 +861,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { if ndp.dhcpv6Configuration != configuration { ndp.dhcpv6Configuration = configuration - ndpDisp.OnDHCPv6Configuration(ndp.ep.NICID(), configuration) + ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration) } } @@ -908,7 +908,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { } addrs, _ := opt.Addresses() - ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.NICID(), addrs, opt.Lifetime()) + ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) case header.NDPDNSSearchList: if ndp.ep.protocol.ndpDisp == nil { @@ -916,7 +916,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { } domainNames, _ := opt.DomainNames() - ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.NICID(), domainNames, opt.Lifetime()) + ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) case header.NDPPrefixInformation: prefix := opt.Subnet() @@ -965,7 +965,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { // Let the integrator know a discovered default router is invalidated. if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { - ndpDisp.OnDefaultRouterInvalidated(ndp.ep.NICID(), ip) + ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) } } @@ -982,14 +982,14 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { } // Inform the integrator when we discovered a default router. - if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.NICID(), ip) { + if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) { // Informed by the integrator to not remember the router, do // nothing further. return } state := defaultRouterState{ - invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { ndp.invalidateDefaultRouter(ip) }), } @@ -1012,14 +1012,14 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) } // Inform the integrator when we discovered an on-link prefix. - if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.NICID(), prefix) { + if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) { // Informed by the integrator to not remember the prefix, do // nothing further. return } state := onLinkPrefixState{ - invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { ndp.invalidateOnLinkPrefix(prefix) }), } @@ -1048,7 +1048,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { // Let the integrator know a discovered on-link prefix is invalidated. if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { - ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.NICID(), prefix) + ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix) } } @@ -1164,7 +1164,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } state := slaacPrefixState{ - deprecationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() { + deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) @@ -1172,7 +1172,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint) }), - invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) @@ -1230,7 +1230,7 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config return nil } - if !ndpDisp.OnAutoGenAddress(ndp.ep.NICID(), addr) { + if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) { // Informed by the integrator not to add the address. return nil } @@ -1276,7 +1276,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt addrBytes = header.AppendOpaqueInterfaceIdentifier( addrBytes[:header.IIDOffsetInIPv6Address], prefix, - oIID.NICNameFromID(ndp.ep.NICID(), ndp.ep.nic.Name()), + oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()), dadCounter, oIID.SecretKey, ) @@ -1433,7 +1433,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } state := tempSLAACAddrState{ - deprecationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() { + deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) @@ -1446,7 +1446,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint) }), - invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) @@ -1459,7 +1459,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) }), - regenJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() { + regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) @@ -1677,7 +1677,7 @@ func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint addressEndpoint.SetDeprecated(true) if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.NICID(), addressEndpoint.AddressWithPrefix()) + ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix()) } } @@ -1702,7 +1702,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.NICID(), addr) + ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } prefix := addr.Subnet() @@ -1762,7 +1762,7 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.NICID(), addr) + ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } if !invalidateAddr { @@ -1878,7 +1878,7 @@ func (ndp *ndpState) startSolicitingRouters() { var done bool ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = ndp.ep.stack.Clock().AfterFunc(delay, func() { + ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() { ndp.ep.mu.Lock() if done { // If we reach this point, it means that the RS timer fired after another @@ -1904,7 +1904,7 @@ func (ndp *ndpState) startSolicitingRouters() { ndp.ep.mu.Unlock() localAddr := addressEndpoint.AddressWithPrefix().Address - r, err := ndp.ep.stack.FindRoute(ndp.ep.NICID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */) + r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */) addressEndpoint.DecRef() if err != nil { return @@ -1923,9 +1923,9 @@ func (ndp *ndpState) startSolicitingRouters() { return } - panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.NICID(), err)) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err)) } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.NICID())) + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID())) } // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source @@ -1961,7 +1961,7 @@ func (ndp *ndpState) startSolicitingRouters() { }, pkt, ); err != nil { sent.Dropped.Increment() - log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.NICID(), err) + log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) // Don't send any more messages if we had an error. remaining = 0 } else { @@ -2005,7 +2005,7 @@ func (ndp *ndpState) stopSolicitingRouters() { // initializeTempAddrState initializes state related to temporary SLAAC // addresses. func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.NICID()) + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 1947468fd..25464a03a 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -66,10 +66,11 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) if err := ep.Enable(); err != nil { t.Fatalf("ep.Enable(): %s", err) } + t.Cleanup(ep.Close) return s, ep } diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 572a2c3b6..4e4b00a92 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,10 +70,6 @@ func (f *fwdTestNetworkEndpoint) MTU() uint32 { return f.ep.MTU() - uint32(f.MaxHeaderLength()) } -func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID -} - func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { return 123 } @@ -91,10 +87,6 @@ func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportPr return 0 } -func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { - return f.ep.Capabilities() -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -165,12 +157,12 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } -func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint { +func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { e := &fwdTestNetworkEndpoint{ nicID: nic.ID(), proto: f, dispatcher: dispatcher, - ep: ep, + ep: nic.LinkEndpoint(), } e.AddressableEndpointState.Init(e) return e diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index dcc8363b2..a265fff0a 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -235,7 +235,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e }, } nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ - header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil, nil, nil), + header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), } rng := rand.New(rand.NewSource(time.Now().UnixNano())) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 926ce9cfc..212c6edae 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -124,7 +124,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC for _, netProto := range stack.networkProtocols { netNum := netProto.Number() nic.mu.packetEPs[netNum] = nil - nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic, ep, stack) + nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } nic.linkEP.Attach(nic) @@ -796,7 +796,7 @@ func (n *NIC) Name() string { return n.name } -// LinkEndpoint returns the link endpoint of n. +// LinkEndpoint implements NetworkInterface. func (n *NIC) LinkEndpoint() LinkEndpoint { return n.linkEP } diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index df516aad7..fdd49b77f 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -60,11 +60,6 @@ func (e *testIPv6Endpoint) MTU() uint32 { return e.linkEP.MTU() - header.IPv6MinimumSize } -// Capabilities implements NetworkEndpoint.Capabilities. -func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - // MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize @@ -88,11 +83,6 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip return tcpip.ErrNotSupported } -// NICID implements NetworkEndpoint.NICID. -func (e *testIPv6Endpoint) NICID() tcpip.NICID { - return e.nicID -} - // HandlePacket implements NetworkEndpoint.HandlePacket. func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { } @@ -142,10 +132,10 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) } // NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint { +func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { e := &testIPv6Endpoint{ nicID: nic.ID(), - linkEP: linkEP, + linkEP: nic.LinkEndpoint(), protocol: p, } e.AddressableEndpointState.Init(e) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ef42fd6e1..567e1904e 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -154,10 +154,10 @@ type TransportProtocol interface { Number() tcpip.TransportProtocolNumber // NewEndpoint creates a new endpoint of the transport protocol. - NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) // NewRawEndpoint creates a new raw endpoint of the transport protocol. - NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) // MinimumPacketSize returns the minimum valid packet size of this // transport protocol. The stack automatically drops any packets smaller @@ -485,6 +485,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + + // LinkEndpoint returns the link endpoint backing the interface. + LinkEndpoint() LinkEndpoint } // NetworkEndpoint is the interface that needs to be implemented by endpoints @@ -515,10 +518,6 @@ type NetworkEndpoint interface { // minus the network endpoint max header length. MTU() uint32 - // Capabilities returns the set of capabilities supported by the - // underlying link-layer endpoint. - Capabilities() LinkEndpointCapabilities - // MaxHeaderLength returns the maximum size the network (and lower // level layers combined) headers can have. Higher levels use this // information to reserve space in the front of the packets they're @@ -539,9 +538,6 @@ type NetworkEndpoint interface { // header to the given destination address. It takes ownership of pkt. WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error - // NICID returns the id of the NIC this endpoint belongs to. - NICID() tcpip.NICID - // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. It sets pkt.NetworkHeader. // @@ -586,7 +582,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint + NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 1b008a067..5ade3c832 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -72,17 +72,18 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop |= PacketLoop } + linkEP := nic.LinkEndpoint() r := Route{ NetProto: netProto, LocalAddress: localAddr, - LocalLinkAddress: nic.linkEP.LinkAddress(), + LocalLinkAddress: linkEP.LinkAddress(), RemoteAddress: remoteAddr, addressEndpoint: addressEndpoint, nic: nic, Loop: loop, } - if nic := r.nic; nic.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { + if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 { if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes r.linkCache = nic.stack @@ -94,7 +95,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { - return r.addressEndpoint.NetworkEndpoint().NICID() + return r.nic.ID() } // MaxHeaderLength forwards the call to the network endpoint's implementation. @@ -115,7 +116,7 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot // Capabilities returns the link-layer capabilities of the route. func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.addressEndpoint.NetworkEndpoint().Capabilities() + return r.nic.LinkEndpoint().Capabilities() } // GSOMaxSize returns the maximum GSO packet size. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index b740aa305..57d8e79e0 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -837,7 +837,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp return nil, tcpip.ErrUnknownProtocol } - return t.proto.NewEndpoint(s, network, waiterQueue) + return t.proto.NewEndpoint(network, waiterQueue) } // NewRawEndpoint creates a new raw transport layer endpoint of the given @@ -857,7 +857,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network return nil, tcpip.ErrUnknownProtocol } - return t.proto.NewRawEndpoint(s, network, waiterQueue) + return t.proto.NewRawEndpoint(network, waiterQueue) } // NewPacketEndpoint creates a new packet endpoint listening for the given diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7589306a4..fda22c550 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -91,10 +91,6 @@ func (f *fakeNetworkEndpoint) MTU() uint32 { return f.ep.MTU() - uint32(f.MaxHeaderLength()) } -func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID -} - func (*fakeNetworkEndpoint) DefaultTTL() uint8 { return 123 } @@ -131,10 +127,6 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto return 0 } -func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.ep.Capabilities() -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -207,12 +199,12 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint { +func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &fakeNetworkEndpoint{ nicID: nic.ID(), proto: f, dispatcher: dispatcher, - ep: ep, + ep: nic.LinkEndpoint(), } e.AddressableEndpointState.Init(e) return e diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 8aae60740..62ab6d92f 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -39,7 +39,7 @@ const ( // use it. type fakeTransportEndpoint struct { stack.TransportEndpointInfo - stack *stack.Stack + proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route @@ -59,8 +59,8 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} +func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Abort() { @@ -143,7 +143,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { f.peerAddr = addr.Addr // Find the route. - r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) + r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) if err != nil { return tcpip.ErrNoRoute } @@ -151,7 +151,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) + err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { return err } @@ -190,7 +190,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai } func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { - if err := f.stack.RegisterTransportEndpoint( + if err := f.proto.stack.RegisterTransportEndpoint( a.NIC, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, @@ -218,7 +218,6 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE f.proto.packetCount++ if f.acceptQueue != nil { f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - stack: f.stack, TransportEndpointInfo: stack.TransportEndpointInfo{ ID: f.ID, NetProto: f.NetProto, @@ -262,6 +261,8 @@ type fakeTransportProtocolOptions struct { // fakeTransportProtocol is a transport-layer protocol descriptor. It // aggregates the number of packets received via endpoints of this protocol. type fakeTransportProtocol struct { + stack *stack.Stack + packetCount int controlCount int opts fakeTransportProtocolOptions @@ -271,11 +272,11 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { return fakeTransNumber } -func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil +func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil } -func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return nil, tcpip.ErrUnknownProtocol } @@ -326,8 +327,8 @@ func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { return ok } -func fakeTransFactory(*stack.Stack) stack.TransportProtocol { - return &fakeTransportProtocol{} +func fakeTransFactory(s *stack.Stack) stack.TransportProtocol { + return &fakeTransportProtocol{stack: s} } func TestTransportReceive(t *testing.T) { diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 7484f4ad9..87d510f96 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -37,6 +37,8 @@ const ( // protocol implements stack.TransportProtocol. type protocol struct { + stack *stack.Stack + number tcpip.TransportProtocolNumber } @@ -57,20 +59,20 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber { // NewEndpoint creates a new icmp endpoint. It implements // stack.TransportProtocol.NewEndpoint. -func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return newEndpoint(stack, netProto, p.number, waiterQueue) + return newEndpoint(p.stack, netProto, p.number, waiterQueue) } // NewRawEndpoint creates a new raw icmp endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return raw.NewEndpoint(stack, netProto, p.number, waiterQueue) + return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue) } // MinimumPacketSize returns the minimum valid icmp packet size. @@ -130,11 +132,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol4 returns an ICMPv4 transport protocol. -func NewProtocol4(*stack.Stack) stack.TransportProtocol { - return &protocol{ProtocolNumber4} +func NewProtocol4(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s, number: ProtocolNumber4} } // NewProtocol6 returns an ICMPv6 transport protocol. -func NewProtocol6(*stack.Stack) stack.TransportProtocol { - return &protocol{ProtocolNumber6} +func NewProtocol6(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s, number: ProtocolNumber6} } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 6a3c2c32b..5bce73605 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -133,6 +133,8 @@ func (s *synRcvdCounter) Threshold() uint64 { } type protocol struct { + stack *stack.Stack + mu sync.RWMutex sackEnabled bool recovery tcpip.TCPRecovery @@ -159,14 +161,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newEndpoint(stack, netProto, waiterQueue), nil +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently // unsupported. It implements stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue) +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue) } // MinimumPacketSize returns the minimum valid tcp packet size. @@ -505,8 +507,9 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol returns a TCP transport protocol. -func NewProtocol(*stack.Stack) stack.TransportProtocol { +func NewProtocol(s *stack.Stack) stack.TransportProtocol { p := protocol{ + stack: s, sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index e6fc23258..da5b1deb2 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -45,6 +45,7 @@ const ( ) type protocol struct { + stack *stack.Stack } // Number returns the udp protocol number. @@ -53,14 +54,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new udp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newEndpoint(stack, netProto, waiterQueue), nil +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw UDP endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue) +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue) } // MinimumPacketSize returns the minimum valid udp packet size. @@ -114,6 +115,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol returns a UDP transport protocol. -func NewProtocol(*stack.Stack) stack.TransportProtocol { - return &protocol{} +func NewProtocol(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 22a809efa..7aaedb708 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1486,6 +1486,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) LinkEndpoint() stack.LinkEndpoint { + return nil +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1509,10 +1513,7 @@ func TestTTL(t *testing.T) { } else { p = ipv6.NewProtocol(nil) } - ep := p.NewEndpoint(&testInterface{}, nil, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - })) + ep := p.NewEndpoint(&testInterface{}, nil, nil, nil) wantTTL = ep.DefaultTTL() ep.Close() } |