From 257703c050e5901aeb3734f200f5a6b41856b4d9 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 9 Oct 2020 12:07:02 -0700 Subject: Automated rollback of changelist 336304024 PiperOrigin-RevId: 336339194 --- pkg/tcpip/stack/BUILD | 5 +- pkg/tcpip/stack/forwarder.go | 131 ----- pkg/tcpip/stack/forwarder_test.go | 878 --------------------------------- pkg/tcpip/stack/forwarding_test.go | 876 ++++++++++++++++++++++++++++++++ pkg/tcpip/stack/neighbor_entry.go | 4 +- pkg/tcpip/stack/neighbor_entry_test.go | 5 +- pkg/tcpip/stack/nic.go | 125 +++-- pkg/tcpip/stack/nic_test.go | 10 +- pkg/tcpip/stack/pending_packets.go | 133 +++++ pkg/tcpip/stack/registration.go | 33 +- pkg/tcpip/stack/route.go | 50 +- pkg/tcpip/stack/stack.go | 42 +- pkg/tcpip/stack/stack_test.go | 62 +-- 13 files changed, 1153 insertions(+), 1201 deletions(-) delete mode 100644 pkg/tcpip/stack/forwarder.go delete mode 100644 pkg/tcpip/stack/forwarder_test.go create mode 100644 pkg/tcpip/stack/forwarding_test.go create mode 100644 pkg/tcpip/stack/pending_packets.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 2eaeab779..eba97334e 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,7 +56,6 @@ go_library( srcs = [ "addressable_endpoint_state.go", "conntrack.go", - "forwarder.go", "headertype_string.go", "icmp_rate_limit.go", "iptables.go", @@ -73,6 +72,7 @@ go_library( "nud.go", "packet_buffer.go", "packet_buffer_list.go", + "pending_packets.go", "rand.go", "registration.go", "route.go", @@ -123,7 +123,6 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", @@ -139,7 +138,7 @@ go_test( name = "stack_test", size = "small", srcs = [ - "forwarder_test.go", + "forwarding_test.go", "linkaddrcache_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go deleted file mode 100644 index 3eff141e6..000000000 --- a/pkg/tcpip/stack/forwarder.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" -) - -const ( - // maxPendingResolutions is the maximum number of pending link-address - // resolutions. - maxPendingResolutions = 64 - maxPendingPacketsPerResolution = 256 -) - -type pendingPacket struct { - nic *NIC - route *Route - proto tcpip.NetworkProtocolNumber - pkt *PacketBuffer -} - -type forwardQueue struct { - sync.Mutex - - // The packets to send once the resolver completes. - packets map[<-chan struct{}][]*pendingPacket - - // FIFO of channels used to cancel the oldest goroutine waiting for - // link-address resolution. - cancelChans []chan struct{} -} - -func newForwardQueue() *forwardQueue { - return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} -} - -func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - shouldWait := false - - f.Lock() - packets, ok := f.packets[ch] - if !ok { - shouldWait = true - } - for len(packets) == maxPendingPacketsPerResolution { - p := packets[0] - packets = packets[1:] - p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() - p.route.Release() - } - if l := len(packets); l >= maxPendingPacketsPerResolution { - panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) - } - f.packets[ch] = append(packets, &pendingPacket{ - nic: n, - route: r, - proto: protocol, - pkt: pkt, - }) - f.Unlock() - - if !shouldWait { - return - } - - // Wait for the link-address resolution to complete. - // Start a goroutine with a forwarding-cancel channel so that we can - // limit the maximum number of goroutines running concurrently. - cancel := f.newCancelChannel() - go func() { - cancelled := false - select { - case <-ch: - case <-cancel: - cancelled = true - } - - f.Lock() - packets := f.packets[ch] - delete(f.packets, ch) - f.Unlock() - - for _, p := range packets { - if cancelled { - p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() - } else if _, err := p.route.Resolve(nil); err != nil { - p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() - } else { - p.nic.forwardPacket(p.route, p.proto, p.pkt) - } - p.route.Release() - } - }() -} - -// newCancelChannel creates a channel that can cancel a pending forwarding -// activity. The oldest channel is closed if the number of open channels would -// exceed maxPendingResolutions. -func (f *forwardQueue) newCancelChannel() chan struct{} { - f.Lock() - defer f.Unlock() - - if len(f.cancelChans) == maxPendingResolutions { - ch := f.cancelChans[0] - f.cancelChans = f.cancelChans[1:] - close(ch) - } - if l := len(f.cancelChans); l >= maxPendingResolutions { - panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions)) - } - - ch := make(chan struct{}) - f.cancelChans = append(f.cancelChans, ch) - return ch -} diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go deleted file mode 100644 index 4e4b00a92..000000000 --- a/pkg/tcpip/stack/forwarder_test.go +++ /dev/null @@ -1,878 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "encoding/binary" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -const ( - fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 - fwdTestNetHeaderLen = 12 - fwdTestNetDefaultPrefixLen = 8 - - // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match - // the MTU of loopback interfaces on linux systems. - fwdTestNetDefaultMTU = 65536 - - dstAddrOffset = 0 - srcAddrOffset = 1 - protocolNumberOffset = 2 -) - -// fwdTestNetworkEndpoint is a network-layer protocol endpoint. -// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only -// use the first three: destination address, source address, and transport -// protocol. They're all one byte fields to simplify parsing. -type fwdTestNetworkEndpoint struct { - AddressableEndpointState - - nicID tcpip.NICID - proto *fwdTestNetworkProtocol - dispatcher TransportDispatcher - ep LinkEndpoint -} - -var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) - -func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error { - return nil -} - -func (*fwdTestNetworkEndpoint) Enabled() bool { - return true -} - -func (*fwdTestNetworkEndpoint) Disable() {} - -func (f *fwdTestNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) -} - -func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { - return 123 -} - -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) -} - -func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { - return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen -} - -func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - -func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return f.proto.Number() -} - -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { - // Add the protocol's header to the packet and send it to the link - // endpoint. - b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) - b[dstAddrOffset] = r.RemoteAddress[0] - b[srcAddrOffset] = r.LocalAddress[0] - b[protocolNumberOffset] = byte(params.Protocol) - - return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { - panic("not implemented") -} - -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported -} - -func (f *fwdTestNetworkEndpoint) Close() { - f.AddressableEndpointState.Cleanup() -} - -// fwdTestNetworkProtocol is a network-layer protocol that implements Address -// resolution. -type fwdTestNetworkProtocol struct { - addrCache *linkAddrCache - neigh *neighborCache - addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) - onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) - - mu struct { - sync.RWMutex - forwarding bool - } -} - -var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) -var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) - -func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fwdTestNetNumber -} - -func (f *fwdTestNetworkProtocol) MinimumPacketSize() int { - return fwdTestNetHeaderLen -} - -func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { - return fwdTestNetDefaultPrefixLen -} - -func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) -} - -func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen) - if !ok { - return 0, false, false - } - return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true -} - -func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { - e := &fwdTestNetworkEndpoint{ - nicID: nic.ID(), - proto: f, - dispatcher: dispatcher, - ep: nic.LinkEndpoint(), - } - e.AddressableEndpointState.Init(e) - return e -} - -func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption -} - -func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption -} - -func (*fwdTestNetworkProtocol) Close() {} - -func (*fwdTestNetworkProtocol) Wait() {} - -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { - if f.onLinkAddressResolved != nil { - time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) - }) - } - return nil -} - -func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if f.onResolveStaticAddress != nil { - return f.onResolveStaticAddress(addr) - } - return "", false -} - -func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return fwdTestNetNumber -} - -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) Forwarding() bool { - f.mu.RLock() - defer f.mu.RUnlock() - return f.mu.forwarding - -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) SetForwarding(v bool) { - f.mu.Lock() - defer f.mu.Unlock() - f.mu.forwarding = v -} - -// fwdTestPacketInfo holds all the information about an outbound packet. -type fwdTestPacketInfo struct { - RemoteLinkAddress tcpip.LinkAddress - LocalLinkAddress tcpip.LinkAddress - Pkt *PacketBuffer -} - -type fwdTestLinkEndpoint struct { - dispatcher NetworkDispatcher - mtu uint32 - linkAddr tcpip.LinkAddress - - // C is where outbound packets are queued. - C chan fwdTestPacketInfo -} - -// InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - e.InjectLinkAddr(protocol, "", pkt) -} - -// InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) -} - -// Attach saves the stack network-layer dispatcher for use later when packets -// are injected. -func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) { - e.dispatcher = dispatcher -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *fwdTestLinkEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized -// during construction. -func (e *fwdTestLinkEndpoint) MTU() uint32 { - return e.mtu -} - -// Capabilities implements stack.LinkEndpoint.Capabilities. -func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { - caps := LinkEndpointCapabilities(0) - return caps | CapabilityResolutionRequired -} - -// GSOMaxSize returns the maximum GSO packet size. -func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { - return 1 << 15 -} - -// MaxHeaderLength returns the maximum size of the link layer header. Given it -// doesn't have a header, it just returns 0. -func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -// LinkAddress returns the link address of this endpoint. -func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return e.linkAddr -} - -func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { - p := fwdTestPacketInfo{ - RemoteLinkAddress: r.RemoteLinkAddress, - LocalLinkAddress: r.LocalLinkAddress, - Pkt: pkt, - } - - select { - case e.C <- p: - default: - } - - return nil -} - -// WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - n := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, gso, protocol, pkt) - n++ - } - - return n, nil -} - -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - p := fwdTestPacketInfo{ - Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), - } - - select { - case e.C <- p: - default: - } - - return nil -} - -// Wait implements stack.LinkEndpoint.Wait. -func (*fwdTestLinkEndpoint) Wait() {} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { - panic("not implemented") -} - -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - panic("not implemented") -} - -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { - // Create a stack with the network protocol and two NICs. - s := New(Options{ - NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, - UseNeighborCache: useNeighborCache, - }) - - if !useNeighborCache { - proto.addrCache = s.linkAddrCache - } - - // Enable forwarding. - s.SetForwarding(proto.Number(), true) - - // NIC 1 has the link address "a", and added the network address 1. - ep1 = &fwdTestLinkEndpoint{ - C: make(chan fwdTestPacketInfo, 300), - mtu: fwdTestNetDefaultMTU, - linkAddr: "a", - } - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) - } - - // NIC 2 has the link address "b", and added the network address 2. - ep2 = &fwdTestLinkEndpoint{ - C: make(chan fwdTestPacketInfo, 300), - mtu: fwdTestNetDefaultMTU, - linkAddr: "b", - } - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) - } - - if useNeighborCache { - // Control the neighbor cache for NIC 2. - nic, ok := s.nics[2] - if !ok { - t.Fatal("failed to get the neighbor cache for NIC 2") - } - proto.neigh = nic.neigh - } - - // Route all packets to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) - } - - return ep1, ep2 -} - -func TestForwardingWithStaticResolver(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Create a network protocol with a static resolver. - proto := &fwdTestNetworkProtocol{ - onResolveStaticAddress: - // The network address 3 is resolved to the link address "c". - func(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "\x03" { - return "c", true - } - return "", false - }, - } - - ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } - - // Test that the static address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - }) - } -} - -func TestForwardingWithFakeResolver(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - proto *fwdTestNetworkProtocol - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any address will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - }, - }, - }, - { - name: "neighborCache", - useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { - t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) - } - // Any address will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - }) - } -} - -func TestForwardingWithNoResolver(t *testing.T) { - // Create a network protocol without a resolver. - proto := &fwdTestNetworkProtocol{} - - // Whether or not we use the neighbor cache here does not matter since - // neither linkAddrCache nor neighborCache will be used. - ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */) - - // inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - select { - case <-ep2.C: - t.Fatal("Packet should not be forwarded") - case <-time.After(time.Second): - } -} - -func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - proto *fwdTestNetworkProtocol - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - } - }, - }, - }, - { - name: "neighborCache", - useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { - t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) - } - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - } - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - - // Inject an inbound packet to address 4 on NIC 1. This packet should - // not be forwarded. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf = buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - }) - } -} - -func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - proto *fwdTestNetworkProtocol - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - }, - }, - }, - { - name: "neighborCache", - useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { - t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - - // Inject two inbound packets to address 3 on NIC 1. - for i := 0; i < 2; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < 2; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } - }) - } -} - -func TestForwardingWithFakeResolverManyPackets(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - proto *fwdTestNetworkProtocol - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - }, - }, - }, - { - name: "neighborCache", - useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { - t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - - for i := 0; i < maxPendingPacketsPerResolution+5; i++ { - // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - // Set the packet sequence number. - binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < maxPendingPacketsPerResolution; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - b := PayloadSince(p.Pkt.NetworkHeader()) - if b[dstAddrOffset] != 3 { - t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) - } - if len(b) < fwdTestNetHeaderLen+2 { - t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) - } - seqNumBuf := b[fwdTestNetHeaderLen:] - - // The first 5 packets should not be forwarded so the sequence number should - // start with 5. - want := uint16(i + 5) - if n := binary.BigEndian.Uint16(seqNumBuf); n != want { - t.Fatalf("got the packet #%d, want = #%d", n, want) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } - }) - } -} - -func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - proto *fwdTestNetworkProtocol - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - }, - }, - }, - { - name: "neighborCache", - useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { - t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - - for i := 0; i < maxPendingResolutions+5; i++ { - // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. - // Each packet has a different destination address (3 to - // maxPendingResolutions + 7). - buf := buffer.NewView(30) - buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < maxPendingResolutions; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - // The first 5 packets (address 3 to 7) should not be forwarded - // because their address resolutions are interrupted. - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } - }) - } -} diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go new file mode 100644 index 000000000..cf042309e --- /dev/null +++ b/pkg/tcpip/stack/forwarding_test.go @@ -0,0 +1,876 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "encoding/binary" + "math" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + fwdTestNetHeaderLen = 12 + fwdTestNetDefaultPrefixLen = 8 + + // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, + // except where another value is explicitly used. It is chosen to match + // the MTU of loopback interfaces on linux systems. + fwdTestNetDefaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 +) + +// fwdTestNetworkEndpoint is a network-layer protocol endpoint. +// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only +// use the first three: destination address, source address, and transport +// protocol. They're all one byte fields to simplify parsing. +type fwdTestNetworkEndpoint struct { + AddressableEndpointState + + nic NetworkInterface + proto *fwdTestNetworkProtocol + dispatcher TransportDispatcher +} + +var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) + +func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error { + return nil +} + +func (*fwdTestNetworkEndpoint) Enabled() bool { + return true +} + +func (*fwdTestNetworkEndpoint) Disable() {} + +func (f *fwdTestNetworkEndpoint) MTU() uint32 { + return f.nic.MTU() - uint32(f.MaxHeaderLength()) +} + +func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { + return 123 +} + +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { + // Dispatch the packet to the transport protocol. + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) +} + +func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { + return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen +} + +func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { + return 0 +} + +func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return f.proto.Number() +} + +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { + // Add the protocol's header to the packet and send it to the link + // endpoint. + b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) + b[dstAddrOffset] = r.RemoteAddress[0] + b[srcAddrOffset] = r.LocalAddress[0] + b[protocolNumberOffset] = byte(params.Protocol) + + return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt) +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { + panic("not implemented") +} + +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + +func (f *fwdTestNetworkEndpoint) Close() { + f.AddressableEndpointState.Cleanup() +} + +// fwdTestNetworkProtocol is a network-layer protocol that implements Address +// resolution. +type fwdTestNetworkProtocol struct { + addrCache *linkAddrCache + neigh *neighborCache + addrResolveDelay time.Duration + onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) + onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) + + mu struct { + sync.RWMutex + forwarding bool + } +} + +var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) +var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) + +func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { + return fwdTestNetNumber +} + +func (f *fwdTestNetworkProtocol) MinimumPacketSize() int { + return fwdTestNetHeaderLen +} + +func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { + return fwdTestNetDefaultPrefixLen +} + +func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) +} + +func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen) + if !ok { + return 0, false, false + } + return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true +} + +func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { + e := &fwdTestNetworkEndpoint{ + nic: nic, + proto: f, + dispatcher: dispatcher, + } + e.AddressableEndpointState.Init(e) + return e +} + +func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func (*fwdTestNetworkProtocol) Close() {} + +func (*fwdTestNetworkProtocol) Wait() {} + +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + if f.onLinkAddressResolved != nil { + time.AfterFunc(f.addrResolveDelay, func() { + f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) + }) + } + return nil +} + +func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if f.onResolveStaticAddress != nil { + return f.onResolveStaticAddress(addr) + } + return "", false +} + +func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return fwdTestNetNumber +} + +// Forwarding implements stack.ForwardingNetworkProtocol. +func (f *fwdTestNetworkProtocol) Forwarding() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.forwarding + +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (f *fwdTestNetworkProtocol) SetForwarding(v bool) { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.forwarding = v +} + +// fwdTestPacketInfo holds all the information about an outbound packet. +type fwdTestPacketInfo struct { + RemoteLinkAddress tcpip.LinkAddress + LocalLinkAddress tcpip.LinkAddress + Pkt *PacketBuffer +} + +type fwdTestLinkEndpoint struct { + dispatcher NetworkDispatcher + mtu uint32 + linkAddr tcpip.LinkAddress + + // C is where outbound packets are queued. + C chan fwdTestPacketInfo +} + +// InjectInbound injects an inbound packet. +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + e.InjectLinkAddr(protocol, "", pkt) +} + +// InjectLinkAddr injects an inbound packet with a remote link address. +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) +} + +// Attach saves the stack network-layer dispatcher for use later when packets +// are injected. +func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *fwdTestLinkEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *fwdTestLinkEndpoint) MTU() uint32 { + return e.mtu +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { + caps := LinkEndpointCapabilities(0) + return caps | CapabilityResolutionRequired +} + +// GSOMaxSize returns the maximum GSO packet size. +func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { + return 1 << 15 +} + +// MaxHeaderLength returns the maximum size of the link layer header. Given it +// doesn't have a header, it just returns 0. +func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return e.linkAddr +} + +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + p := fwdTestPacketInfo{ + RemoteLinkAddress: r.RemoteLinkAddress, + LocalLinkAddress: r.LocalLinkAddress, + Pkt: pkt, + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// WritePackets stores outbound packets into the channel. +func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.WritePacket(r, gso, protocol, pkt) + n++ + } + + return n, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + p := fwdTestPacketInfo{ + Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// Wait implements stack.LinkEndpoint.Wait. +func (*fwdTestLinkEndpoint) Wait() {} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { + // Create a stack with the network protocol and two NICs. + s := New(Options{ + NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, + UseNeighborCache: useNeighborCache, + }) + + if !useNeighborCache { + proto.addrCache = s.linkAddrCache + } + + // Enable forwarding. + s.SetForwarding(proto.Number(), true) + + // NIC 1 has the link address "a", and added the network address 1. + ep1 = &fwdTestLinkEndpoint{ + C: make(chan fwdTestPacketInfo, 300), + mtu: fwdTestNetDefaultMTU, + linkAddr: "a", + } + if err := s.CreateNIC(1, ep1); err != nil { + t.Fatal("CreateNIC #1 failed:", err) + } + if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { + t.Fatal("AddAddress #1 failed:", err) + } + + // NIC 2 has the link address "b", and added the network address 2. + ep2 = &fwdTestLinkEndpoint{ + C: make(chan fwdTestPacketInfo, 300), + mtu: fwdTestNetDefaultMTU, + linkAddr: "b", + } + if err := s.CreateNIC(2, ep2); err != nil { + t.Fatal("CreateNIC #2 failed:", err) + } + if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { + t.Fatal("AddAddress #2 failed:", err) + } + + if useNeighborCache { + // Control the neighbor cache for NIC 2. + nic, ok := s.nics[2] + if !ok { + t.Fatal("failed to get the neighbor cache for NIC 2") + } + proto.neigh = nic.neigh + } + + // Route all packets to NIC 2. + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) + } + + return ep1, ep2 +} + +func TestForwardingWithStaticResolver(t *testing.T) { + tests := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Create a network protocol with a static resolver. + proto := &fwdTestNetworkProtocol{ + onResolveStaticAddress: + // The network address 3 is resolved to the link address "c". + func(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\x03" { + return "c", true + } + return "", false + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + default: + t.Fatal("packet not forwarded") + } + + // Test that the static address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + }) + } +} + +func TestForwardingWithFakeResolver(t *testing.T) { + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any address will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any address will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + }) + } +} + +func TestForwardingWithNoResolver(t *testing.T) { + // Create a network protocol without a resolver. + proto := &fwdTestNetworkProtocol{} + + // Whether or not we use the neighbor cache here does not matter since + // neither linkAddrCache nor neighborCache will be used. + ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */) + + // inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + select { + case <-ep2.C: + t.Fatal("Packet should not be forwarded") + case <-time.After(time.Second): + } +} + +func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + } + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + } + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + + // Inject an inbound packet to address 4 on NIC 1. This packet should + // not be forwarded. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 4 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf = buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + }) + } +} + +func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + + // Inject two inbound packets to address 3 on NIC 1. + for i := 0; i < 2; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } + + for i := 0; i < 2; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } + }) + } +} + +func TestForwardingWithFakeResolverManyPackets(t *testing.T) { + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + + for i := 0; i < maxPendingPacketsPerResolution+5; i++ { + // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + // Set the packet sequence number. + binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } + + for i := 0; i < maxPendingPacketsPerResolution; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := PayloadSince(p.Pkt.NetworkHeader()) + if b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) + } + if len(b) < fwdTestNetHeaderLen+2 { + t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) + } + seqNumBuf := b[fwdTestNetHeaderLen:] + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. + want := uint16(i + 5) + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { + t.Fatalf("got the packet #%d, want = #%d", n, want) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } + }) + } +} + +func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + + for i := 0; i < maxPendingResolutions+5; i++ { + // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. + // Each packet has a different destination address (3 to + // maxPendingResolutions + 7). + buf := buffer.NewView(30) + buf[dstAddrOffset] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } + + for i := 0; i < maxPendingResolutions; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // The first 5 packets (address 3 to 7) should not be forwarded + // because their address resolutions are interrupted. + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } + }) + } +} diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 9a72bec79..4d69a4de1 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -236,7 +236,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil { // There is no need to log the error here; the NUD implementation may // assume a working link. A valid link should be the responsibility of // the NIC/stack.LinkEndpoint. @@ -277,7 +277,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index a265fff0a..e79abebca 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -227,8 +227,9 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e clock := faketime.NewManualClock() disp := testNUDDispatcher{} nic := NIC{ - id: entryTestNICID, - linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint + LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint + + id: entryTestNICID, stack: &Stack{ clock: clock, nudDisp: &disp, diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 6cf54cc89..8828cc5fe 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -32,10 +32,11 @@ var _ NetworkInterface = (*NIC)(nil) // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { + LinkEndpoint + stack *Stack id tcpip.NICID name string - linkEP LinkEndpoint context NICContext stats NICStats @@ -91,10 +92,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // of IPv6 is supported on this endpoint's LinkEndpoint. nic := &NIC{ + LinkEndpoint: ep, + stack: stack, id: id, name: name, - linkEP: ep, context: ctx, stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), @@ -130,7 +132,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } - nic.linkEP.Attach(nic) + nic.LinkEndpoint.Attach(nic) return nic } @@ -220,7 +222,7 @@ func (n *NIC) remove() *tcpip.Error { } // Detach from link endpoint, so no packet comes in. - n.linkEP.Attach(nil) + n.LinkEndpoint.Attach(nil) return nil } @@ -240,7 +242,64 @@ func (n *NIC) isPromiscuousMode() bool { // IsLoopback implements NetworkInterface. func (n *NIC) IsLoopback() bool { - return n.linkEP.Capabilities()&CapabilityLoopback != 0 + return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0 +} + +// WritePacket implements NetworkLinkEndpoint. +func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + // As per relevant RFCs, we should queue packets while we wait for link + // resolution to complete. + // + // RFC 1122 section 2.3.2.2 (for IPv4): + // The link layer SHOULD save (rather than discard) at least + // one (the latest) packet of each set of packets destined to + // the same unresolved IP address, and transmit the saved + // packet when the address has been resolved. + // + // RFC 4861 section 5.2 (for IPv6): + // Once the IP address of the next-hop node is known, the sender + // examines the Neighbor Cache for link-layer information about that + // neighbor. If no entry exists, the sender creates one, sets its state + // to INCOMPLETE, initiates Address Resolution, and then queues the data + // packet pending completion of address resolution. + if ch, err := r.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + r := r.Clone() + n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt) + return nil + } + return err + } + + return n.writePacket(r, gso, protocol, pkt) +} + +func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Size() + + if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil { + return err + } + + n.stats.Tx.Packets.Increment() + n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + return nil +} + +// WritePackets implements NetworkLinkEndpoint. +func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution + // is being peformed like WritePacket. + writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) + n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) + writtenBytes := 0 + for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { + writtenBytes += pb.Size() + } + + n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + return writtenPackets, err } // setSpoofing enables or disables address spoofing. @@ -525,7 +584,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // If no local link layer address is provided, assume it was sent // directly to this NIC. if local == "" { - local = n.linkEP.LinkAddress() + local = n.LinkEndpoint.LinkAddress() } // Are any packet type sockets listening for this network protocol? @@ -605,7 +664,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n := r.nic if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { if n.isValidForOutgoing(addressEndpoint) { - r.LocalLinkAddress = n.linkEP.LinkAddress() + r.LocalLinkAddress = n.LinkEndpoint.LinkAddress() r.RemoteLinkAddress = remote r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. @@ -620,21 +679,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. - // TODO(b/128629022): move this logic to route.WritePacket. // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. - if ch, err := r.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) - // forwarder will release route. - return - } + + // pkt may have set its header and may not have enough headroom for + // link-layer header for the other link to prepend. Here we create a new + // packet to forward. + fwdPkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) + + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil { n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - r.Release() - return } - // The link-address resolution finished immediately. - n.forwardPacket(&r, protocol, pkt) r.Release() return } @@ -658,34 +717,11 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc p.PktType = tcpip.PacketOutgoing // Add the link layer header as outgoing packets are intercepted // before the link layer header is created. - n.linkEP.AddHeader(local, remote, protocol, p) + n.LinkEndpoint.AddHeader(local, remote, protocol, p) ep.HandlePacket(n.id, local, protocol, p) } } -func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - // pkt may have set its header and may not have enough headroom for link-layer - // header for the other link to prepend. Here we create a new packet to - // forward. - fwdPkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()), - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - }) - - // WritePacket takes ownership of fwdPkt, calculate numBytes first. - numBytes := fwdPkt.Size() - - if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - return - } - - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) -} - // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { @@ -796,11 +832,6 @@ func (n *NIC) Name() string { return n.name } -// LinkEndpoint implements NetworkInterface. -func (n *NIC) LinkEndpoint() LinkEndpoint { - return n.linkEP -} - // nudConfigs gets the NUD configurations for n. func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { if n.neigh == nil { diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index fdd49b77f..97a96af62 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -33,8 +33,7 @@ var _ NDPEndpoint = (*testIPv6Endpoint)(nil) type testIPv6Endpoint struct { AddressableEndpointState - nicID tcpip.NICID - linkEP LinkEndpoint + nic NetworkInterface protocol *testIPv6Protocol invalidatedRtr tcpip.Address @@ -57,12 +56,12 @@ func (*testIPv6Endpoint) DefaultTTL() uint8 { // MTU implements NetworkEndpoint.MTU. func (e *testIPv6Endpoint) MTU() uint32 { - return e.linkEP.MTU() - header.IPv6MinimumSize + return e.nic.MTU() - header.IPv6MinimumSize } // MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize + return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } // WritePacket implements NetworkEndpoint.WritePacket. @@ -134,8 +133,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint implements NetworkProtocol.NewEndpoint. func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { e := &testIPv6Endpoint{ - nicID: nic.ID(), - linkEP: nic.LinkEndpoint(), + nic: nic, protocol: p, } e.AddressableEndpointState.Init(e) diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go new file mode 100644 index 000000000..f838eda8d --- /dev/null +++ b/pkg/tcpip/stack/pending_packets.go @@ -0,0 +1,133 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // maxPendingResolutions is the maximum number of pending link-address + // resolutions. + maxPendingResolutions = 64 + maxPendingPacketsPerResolution = 256 +) + +type pendingPacket struct { + route *Route + proto tcpip.NetworkProtocolNumber + pkt *PacketBuffer +} + +// packetsPendingLinkResolution is a queue of packets pending link resolution. +// +// Once link resolution completes successfully, the packets will be written. +type packetsPendingLinkResolution struct { + sync.Mutex + + // The packets to send once the resolver completes. + packets map[<-chan struct{}][]pendingPacket + + // FIFO of channels used to cancel the oldest goroutine waiting for + // link-address resolution. + cancelChans []chan struct{} +} + +func (f *packetsPendingLinkResolution) init() { + f.Lock() + defer f.Unlock() + f.packets = make(map[<-chan struct{}][]pendingPacket) +} + +func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + f.Lock() + defer f.Unlock() + + packets, ok := f.packets[ch] + if len(packets) == maxPendingPacketsPerResolution { + p := packets[0] + packets[0] = pendingPacket{} + packets = packets[1:] + p.route.Stats().IP.OutgoingPacketErrors.Increment() + p.route.Release() + } + + if l := len(packets); l >= maxPendingPacketsPerResolution { + panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) + } + + f.packets[ch] = append(packets, pendingPacket{ + route: r, + proto: proto, + pkt: pkt, + }) + + if ok { + return + } + + // Wait for the link-address resolution to complete. + cancel := f.newCancelChannelLocked() + go func() { + cancelled := false + select { + case <-ch: + case <-cancel: + cancelled = true + } + + f.Lock() + packets, ok := f.packets[ch] + delete(f.packets, ch) + f.Unlock() + + if !ok { + panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets")) + } + + for _, p := range packets { + if cancelled { + p.route.Stats().IP.OutgoingPacketErrors.Increment() + } else if _, err := p.route.Resolve(nil); err != nil { + p.route.Stats().IP.OutgoingPacketErrors.Increment() + } else { + p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt) + } + p.route.Release() + } + }() +} + +// newCancelChannel creates a channel that can cancel a pending forwarding +// activity. The oldest channel is closed if the number of open channels would +// exceed maxPendingResolutions. +func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} { + if len(f.cancelChans) == maxPendingResolutions { + ch := f.cancelChans[0] + f.cancelChans[0] = nil + f.cancelChans = f.cancelChans[1:] + close(ch) + } + if l := len(f.cancelChans); l >= maxPendingResolutions { + panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions)) + } + + ch := make(chan struct{}) + f.cancelChans = append(f.cancelChans, ch) + return ch +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index be9bd8042..defb9129b 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -475,6 +475,8 @@ type NDPEndpoint interface { // NetworkInterface is a network interface. type NetworkInterface interface { + NetworkLinkEndpoint + // ID returns the interface's ID. ID() tcpip.NICID @@ -488,9 +490,6 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool - - // LinkEndpoint returns the link endpoint backing the interface. - LinkEndpoint() LinkEndpoint } // NetworkEndpoint is the interface that needs to be implemented by endpoints @@ -663,22 +662,15 @@ const ( CapabilitySoftwareGSO ) -// LinkEndpoint is the interface implemented by data link layer protocols (e.g., -// ethernet, loopback, raw) and used by network layer protocols to send packets -// out through the implementer's data link endpoint. When a link header exists, -// it sets each PacketBuffer's LinkHeader field before passing it up the -// stack. -type LinkEndpoint interface { +// NetworkLinkEndpoint is a data-link layer that supports sending network +// layer packets. +type NetworkLinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is // usually dictated by the backing physical network; when such a // physical network doesn't exist, the limit is generally 64k, which // includes the maximum size of an IP packet. MTU() uint32 - // Capabilities returns the set of capabilities supported by the - // endpoint. - Capabilities() LinkEndpointCapabilities - // MaxHeaderLength returns the maximum size the data link (and // lower level layers combined) headers can have. Higher levels use this // information to reserve space in the front of the packets they're @@ -686,7 +678,7 @@ type LinkEndpoint interface { MaxHeaderLength() uint16 // LinkAddress returns the link address (typically a MAC) of the - // link endpoint. + // endpoint. LinkAddress() tcpip.LinkAddress // WritePacket writes a packet with the given protocol through the @@ -706,6 +698,19 @@ type LinkEndpoint interface { // offload is enabled. If it will be used for something else, it may // require to change syscall filters. WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) +} + +// LinkEndpoint is the interface implemented by data link layer protocols (e.g., +// ethernet, loopback, raw) and used by network layer protocols to send packets +// out through the implementer's data link endpoint. When a link header exists, +// it sets each PacketBuffer's LinkHeader field before passing it up the +// stack. +type LinkEndpoint interface { + NetworkLinkEndpoint + + // Capabilities returns the set of capabilities supported by the + // endpoint. + Capabilities() LinkEndpointCapabilities // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. It takes ownership of vv. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index cc39c9a6a..25f80c1f8 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -72,21 +72,20 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop |= PacketLoop } - linkEP := nic.LinkEndpoint() r := Route{ NetProto: netProto, LocalAddress: localAddr, - LocalLinkAddress: linkEP.LinkAddress(), + LocalLinkAddress: nic.LinkEndpoint.LinkAddress(), RemoteAddress: remoteAddr, addressEndpoint: addressEndpoint, nic: nic, Loop: loop, } - if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok { + if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes - r.linkCache = nic.stack + r.linkCache = r.nic.stack } } @@ -116,23 +115,17 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot // Capabilities returns the link-layer capabilities of the route. func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.nic.LinkEndpoint().Capabilities() + return r.nic.LinkEndpoint.Capabilities() } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.nic.getNetworkEndpoint(r.NetProto).(GSOEndpoint); ok { + if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 } -// ResolveWith immediately resolves a route with the specified remote link -// address. -func (r *Route) ResolveWith(addr tcpip.LinkAddress) { - r.RemoteLinkAddress = addr -} - // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). @@ -208,16 +201,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf return tcpip.ErrInvalidEndpointState } - // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Size() - - if err := r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt); err != nil { - return err - } - - r.nic.stats.Tx.Packets.Increment() - r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) - return nil + return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) } // WritePackets writes a list of n packets through the given route and returns @@ -227,15 +211,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead return 0, tcpip.ErrInvalidEndpointState } - n, err := r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) - r.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - writtenBytes := 0 - for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { - writtenBytes += pb.Size() - } - - r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) - return n, err + return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) } // WriteHeaderIncludedPacket writes a packet already containing a network @@ -245,15 +221,7 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Data.Size() - - if err := r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt); err != nil { - return err - } - r.nic.stats.Tx.Packets.Increment() - r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) - return nil + return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) } // DefaultTTL returns the default TTL of the underlying network endpoint. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0bf20c0e1..3a07577c8 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -436,9 +436,9 @@ type Stack struct { // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID - // forwarder holds the packets that wait for their link-address resolutions - // to complete, and forwards them when each resolution is done. - forwarder *forwardQueue + // linkResQueue holds packets that are waiting for link resolution to + // complete. + linkResQueue packetsPendingLinkResolution // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. @@ -550,8 +550,8 @@ type TransportEndpointInfo struct { // incompatible with the receiver. // // Preconditon: the parent endpoint mu must be held while calling this method. -func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto +func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := t.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: netProto = header.IPv4ProtocolNumber @@ -565,7 +565,7 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl } } - switch len(e.ID.LocalAddress) { + switch len(t.ID.LocalAddress) { case header.IPv4AddressSize: if len(addr.Addr) == header.IPv6AddressSize { return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState @@ -577,8 +577,8 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl } switch { - case netProto == e.NetProto: - case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber: + case netProto == t.NetProto: + case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber: if v6only { return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute } @@ -640,7 +640,6 @@ func New(opts Options) *Stack { useNeighborCache: opts.UseNeighborCache, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, - forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, @@ -653,6 +652,7 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, } + s.linkResQueue.init() // Add specified network protocols. for _, netProtoFactory := range opts.NetworkProtocols { @@ -928,16 +928,16 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } -// GetNICByName gets the NIC specified by name. -func (s *Stack) GetNICByName(name string) (*NIC, bool) { +// GetLinkEndpointByName gets the link endpoint specified by name. +func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint { s.mu.RLock() defer s.mu.RUnlock() for _, nic := range s.nics { if nic.Name() == name { - return nic, true + return nic.LinkEndpoint } } - return nil, false + return nil } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -1062,13 +1062,13 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { } nics[id] = NICInfo{ Name: nic.name, - LinkAddress: nic.linkEP.LinkAddress(), + LinkAddress: nic.LinkEndpoint.LinkAddress(), ProtocolAddresses: nic.primaryAddresses(), Flags: flags, - MTU: nic.linkEP.MTU(), + MTU: nic.LinkEndpoint.MTU(), Stats: nic.stats, Context: nic.context, - ARPHardwareType: nic.linkEP.ARPHardwareType(), + ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), } } return nics @@ -1323,7 +1323,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker) } // Neighbors returns all IP to MAC address associations. @@ -1539,7 +1539,7 @@ func (s *Stack) Wait() { s.mu.RLock() defer s.mu.RUnlock() for _, n := range s.nics { - n.linkEP.Wait() + n.LinkEndpoint.Wait() } } @@ -1627,7 +1627,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t // Add our own fake ethernet header. ethFields := header.EthernetFields{ - SrcAddr: nic.linkEP.LinkAddress(), + SrcAddr: nic.LinkEndpoint.LinkAddress(), DstAddr: dst, Type: netProto, } @@ -1636,7 +1636,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t vv := buffer.View(fakeHeader).ToVectorisedView() vv.Append(payload) - if err := nic.linkEP.WriteRawPacket(vv); err != nil { + if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil { return err } @@ -1653,7 +1653,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) return tcpip.ErrUnknownDevice } - if err := nic.linkEP.WriteRawPacket(payload); err != nil { + if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil { return err } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index aa20f750b..38994cca1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,7 +21,6 @@ import ( "bytes" "fmt" "math" - "net" "sort" "testing" "time" @@ -35,7 +34,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -77,10 +75,9 @@ type fakeNetworkEndpoint struct { enabled bool } - nicID tcpip.NICID + nic stack.NetworkInterface proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher - ep stack.LinkEndpoint } func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { @@ -103,7 +100,7 @@ func (f *fakeNetworkEndpoint) Disable() { } func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) + return f.nic.MTU() - uint32(f.MaxHeaderLength()) } func (*fakeNetworkEndpoint) DefaultTTL() uint8 { @@ -135,7 +132,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.ep.MaxHeaderLength() + fakeNetHeaderLen + return f.nic.MaxHeaderLength() + fakeNetHeaderLen } func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -164,7 +161,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params return nil } - return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) + return f.nic.WritePacket(r, gso, fakeNetNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -216,10 +213,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &fakeNetworkEndpoint{ - nicID: nic.ID(), + nic: nic, proto: f, dispatcher: dispatcher, - ep: nic.LinkEndpoint(), } e.AddressableEndpointState.Init(e) return e @@ -2106,7 +2102,7 @@ func TestNICStats(t *testing.T) { t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) } - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { + if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -3502,52 +3498,6 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } } -func TestResolveWith(t *testing.T) { - const ( - unspecifiedNICID = 0 - nicID = 1 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - }) - ep := channel.New(0, defaultMTU, "") - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - addr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), - PrefixLen: 24, - }, - } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) - - remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4()) - r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) - } - defer r.Release() - - // Should initially require resolution. - if !r.IsResolutionRequired() { - t.Fatal("got r.IsResolutionRequired() = false, want = true") - } - - // Manually resolving the route should no longer require resolution. - r.ResolveWith("\x01") - if r.IsResolutionRequired() { - t.Fatal("got r.IsResolutionRequired() = true, want = false") - } -} - // TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its // associated address is removed should not cause a panic. func TestRouteReleaseAfterAddrRemoval(t *testing.T) { -- cgit v1.2.3