diff options
author | Bhasker Hariharan <bhaskerh@google.com> | 2020-10-09 09:08:57 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-10-09 09:11:18 -0700 |
commit | 8566decab094008d5f873cb679c972d5d60cc49a (patch) | |
tree | 0ba87479d8d36d3057366b7f26bddb1925b4d767 /pkg/tcpip/stack | |
parent | 07b1d7413e8b648b85fa9276a516732dd93276b4 (diff) |
Automated rollback of changelist 336185457
PiperOrigin-RevId: 336304024
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarder.go (renamed from pkg/tcpip/stack/pending_packets.go) | 60 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarder_test.go (renamed from pkg/tcpip/stack/forwarding_test.go) | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 125 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic_test.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 48 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 42 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 62 |
11 files changed, 227 insertions, 179 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index eba97334e..2eaeab779 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,6 +56,7 @@ go_library( srcs = [ "addressable_endpoint_state.go", "conntrack.go", + "forwarder.go", "headertype_string.go", "icmp_rate_limit.go", "iptables.go", @@ -72,7 +73,6 @@ go_library( "nud.go", "packet_buffer.go", "packet_buffer_list.go", - "pending_packets.go", "rand.go", "registration.go", "route.go", @@ -123,6 +123,7 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", @@ -138,7 +139,7 @@ go_test( name = "stack_test", size = "small", srcs = [ - "forwarding_test.go", + "forwarder_test.go", "linkaddrcache_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/forwarder.go index f838eda8d..3eff141e6 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/forwarder.go @@ -29,60 +29,60 @@ const ( ) type pendingPacket struct { + nic *NIC route *Route proto tcpip.NetworkProtocolNumber pkt *PacketBuffer } -// packetsPendingLinkResolution is a queue of packets pending link resolution. -// -// Once link resolution completes successfully, the packets will be written. -type packetsPendingLinkResolution struct { +type forwardQueue struct { sync.Mutex // The packets to send once the resolver completes. - packets map[<-chan struct{}][]pendingPacket + packets map[<-chan struct{}][]*pendingPacket // FIFO of channels used to cancel the oldest goroutine waiting for // link-address resolution. cancelChans []chan struct{} } -func (f *packetsPendingLinkResolution) init() { - f.Lock() - defer f.Unlock() - f.packets = make(map[<-chan struct{}][]pendingPacket) +func newForwardQueue() *forwardQueue { + return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} } -func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - f.Lock() - defer f.Unlock() +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + shouldWait := false + f.Lock() packets, ok := f.packets[ch] - if len(packets) == maxPendingPacketsPerResolution { + if !ok { + shouldWait = true + } + for len(packets) == maxPendingPacketsPerResolution { p := packets[0] - packets[0] = pendingPacket{} packets = packets[1:] - p.route.Stats().IP.OutgoingPacketErrors.Increment() + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() p.route.Release() } - if l := len(packets); l >= maxPendingPacketsPerResolution { panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) } - - f.packets[ch] = append(packets, pendingPacket{ + f.packets[ch] = append(packets, &pendingPacket{ + nic: n, route: r, - proto: proto, + proto: protocol, pkt: pkt, }) + f.Unlock() - if ok { + if !shouldWait { return } // Wait for the link-address resolution to complete. - cancel := f.newCancelChannelLocked() + // Start a goroutine with a forwarding-cancel channel so that we can + // limit the maximum number of goroutines running concurrently. + cancel := f.newCancelChannel() go func() { cancelled := false select { @@ -92,21 +92,17 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro } f.Lock() - packets, ok := f.packets[ch] + packets := f.packets[ch] delete(f.packets, ch) f.Unlock() - if !ok { - panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets")) - } - for _, p := range packets { if cancelled { - p.route.Stats().IP.OutgoingPacketErrors.Increment() + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() } else if _, err := p.route.Resolve(nil); err != nil { - p.route.Stats().IP.OutgoingPacketErrors.Increment() + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() } else { - p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt) + p.nic.forwardPacket(p.route, p.proto, p.pkt) } p.route.Release() } @@ -116,10 +112,12 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro // newCancelChannel creates a channel that can cancel a pending forwarding // activity. The oldest channel is closed if the number of open channels would // exceed maxPendingResolutions. -func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} { +func (f *forwardQueue) newCancelChannel() chan struct{} { + f.Lock() + defer f.Unlock() + if len(f.cancelChans) == maxPendingResolutions { ch := f.cancelChans[0] - f.cancelChans[0] = nil f.cancelChans = f.cancelChans[1:] close(ch) } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarder_test.go index cf042309e..4e4b00a92 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -48,9 +48,10 @@ const ( type fwdTestNetworkEndpoint struct { AddressableEndpointState - nic NetworkInterface + nicID tcpip.NICID proto *fwdTestNetworkProtocol dispatcher TransportDispatcher + ep LinkEndpoint } var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) @@ -66,7 +67,7 @@ func (*fwdTestNetworkEndpoint) Enabled() bool { func (*fwdTestNetworkEndpoint) Disable() {} func (f *fwdTestNetworkEndpoint) MTU() uint32 { - return f.nic.MTU() - uint32(f.MaxHeaderLength()) + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { @@ -79,7 +80,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { - return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen + return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen } func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -98,7 +99,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH b[srcAddrOffset] = r.LocalAddress[0] b[protocolNumberOffset] = byte(params.Protocol) - return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt) + return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) } // WritePackets implements LinkEndpoint.WritePackets. @@ -158,9 +159,10 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { e := &fwdTestNetworkEndpoint{ - nic: nic, + nicID: nic.ID(), proto: f, dispatcher: dispatcher, + ep: nic.LinkEndpoint(), } e.AddressableEndpointState.Init(e) return e diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 4d69a4de1..9a72bec79 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -236,7 +236,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil { // There is no need to log the error here; the NUD implementation may // assume a working link. A valid link should be the responsibility of // the NIC/stack.LinkEndpoint. @@ -277,7 +277,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index e79abebca..a265fff0a 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -227,9 +227,8 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e clock := faketime.NewManualClock() disp := testNUDDispatcher{} nic := NIC{ - LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint - - id: entryTestNICID, + id: entryTestNICID, + linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint stack: &Stack{ clock: clock, nudDisp: &disp, diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8828cc5fe..6cf54cc89 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -32,11 +32,10 @@ var _ NetworkInterface = (*NIC)(nil) // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { - LinkEndpoint - stack *Stack id tcpip.NICID name string + linkEP LinkEndpoint context NICContext stats NICStats @@ -92,11 +91,10 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // of IPv6 is supported on this endpoint's LinkEndpoint. nic := &NIC{ - LinkEndpoint: ep, - stack: stack, id: id, name: name, + linkEP: ep, context: ctx, stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), @@ -132,7 +130,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } - nic.LinkEndpoint.Attach(nic) + nic.linkEP.Attach(nic) return nic } @@ -222,7 +220,7 @@ func (n *NIC) remove() *tcpip.Error { } // Detach from link endpoint, so no packet comes in. - n.LinkEndpoint.Attach(nil) + n.linkEP.Attach(nil) return nil } @@ -242,64 +240,7 @@ func (n *NIC) isPromiscuousMode() bool { // IsLoopback implements NetworkInterface. func (n *NIC) IsLoopback() bool { - return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0 -} - -// WritePacket implements NetworkLinkEndpoint. -func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { - // As per relevant RFCs, we should queue packets while we wait for link - // resolution to complete. - // - // RFC 1122 section 2.3.2.2 (for IPv4): - // The link layer SHOULD save (rather than discard) at least - // one (the latest) packet of each set of packets destined to - // the same unresolved IP address, and transmit the saved - // packet when the address has been resolved. - // - // RFC 4861 section 5.2 (for IPv6): - // Once the IP address of the next-hop node is known, the sender - // examines the Neighbor Cache for link-layer information about that - // neighbor. If no entry exists, the sender creates one, sets its state - // to INCOMPLETE, initiates Address Resolution, and then queues the data - // packet pending completion of address resolution. - if ch, err := r.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - r := r.Clone() - n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt) - return nil - } - return err - } - - return n.writePacket(r, gso, protocol, pkt) -} - -func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { - // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Size() - - if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil { - return err - } - - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) - return nil -} - -// WritePackets implements NetworkLinkEndpoint. -func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution - // is being peformed like WritePacket. - writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) - n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) - writtenBytes := 0 - for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { - writtenBytes += pb.Size() - } - - n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) - return writtenPackets, err + return n.linkEP.Capabilities()&CapabilityLoopback != 0 } // setSpoofing enables or disables address spoofing. @@ -584,7 +525,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // If no local link layer address is provided, assume it was sent // directly to this NIC. if local == "" { - local = n.LinkEndpoint.LinkAddress() + local = n.linkEP.LinkAddress() } // Are any packet type sockets listening for this network protocol? @@ -664,7 +605,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n := r.nic if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { if n.isValidForOutgoing(addressEndpoint) { - r.LocalLinkAddress = n.LinkEndpoint.LinkAddress() + r.LocalLinkAddress = n.linkEP.LinkAddress() r.RemoteLinkAddress = remote r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. @@ -679,21 +620,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. + // TODO(b/128629022): move this logic to route.WritePacket. // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. - - // pkt may have set its header and may not have enough headroom for - // link-layer header for the other link to prepend. Here we create a new - // packet to forward. - fwdPkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - }) - - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil { + if ch, err := r.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) + // forwarder will release route. + return + } n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() + r.Release() + return } + // The link-address resolution finished immediately. + n.forwardPacket(&r, protocol, pkt) r.Release() return } @@ -717,11 +658,34 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc p.PktType = tcpip.PacketOutgoing // Add the link layer header as outgoing packets are intercepted // before the link layer header is created. - n.LinkEndpoint.AddHeader(local, remote, protocol, p) + n.linkEP.AddHeader(local, remote, protocol, p) ep.HandlePacket(n.id, local, protocol, p) } } +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + + // pkt may have set its header and may not have enough headroom for link-layer + // header for the other link to prepend. Here we create a new packet to + // forward. + fwdPkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()), + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) + + // WritePacket takes ownership of fwdPkt, calculate numBytes first. + numBytes := fwdPkt.Size() + + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return + } + + n.stats.Tx.Packets.Increment() + n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) +} + // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { @@ -832,6 +796,11 @@ func (n *NIC) Name() string { return n.name } +// LinkEndpoint implements NetworkInterface. +func (n *NIC) LinkEndpoint() LinkEndpoint { + return n.linkEP +} + // nudConfigs gets the NUD configurations for n. func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { if n.neigh == nil { diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 97a96af62..fdd49b77f 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -33,7 +33,8 @@ var _ NDPEndpoint = (*testIPv6Endpoint)(nil) type testIPv6Endpoint struct { AddressableEndpointState - nic NetworkInterface + nicID tcpip.NICID + linkEP LinkEndpoint protocol *testIPv6Protocol invalidatedRtr tcpip.Address @@ -56,12 +57,12 @@ func (*testIPv6Endpoint) DefaultTTL() uint8 { // MTU implements NetworkEndpoint.MTU. func (e *testIPv6Endpoint) MTU() uint32 { - return e.nic.MTU() - header.IPv6MinimumSize + return e.linkEP.MTU() - header.IPv6MinimumSize } // MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { - return e.nic.MaxHeaderLength() + header.IPv6MinimumSize + return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize } // WritePacket implements NetworkEndpoint.WritePacket. @@ -133,7 +134,8 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint implements NetworkProtocol.NewEndpoint. func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { e := &testIPv6Endpoint{ - nic: nic, + nicID: nic.ID(), + linkEP: nic.LinkEndpoint(), protocol: p, } e.AddressableEndpointState.Init(e) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index defb9129b..be9bd8042 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -475,8 +475,6 @@ type NDPEndpoint interface { // NetworkInterface is a network interface. type NetworkInterface interface { - NetworkLinkEndpoint - // ID returns the interface's ID. ID() tcpip.NICID @@ -490,6 +488,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + + // LinkEndpoint returns the link endpoint backing the interface. + LinkEndpoint() LinkEndpoint } // NetworkEndpoint is the interface that needs to be implemented by endpoints @@ -662,15 +663,22 @@ const ( CapabilitySoftwareGSO ) -// NetworkLinkEndpoint is a data-link layer that supports sending network -// layer packets. -type NetworkLinkEndpoint interface { +// LinkEndpoint is the interface implemented by data link layer protocols (e.g., +// ethernet, loopback, raw) and used by network layer protocols to send packets +// out through the implementer's data link endpoint. When a link header exists, +// it sets each PacketBuffer's LinkHeader field before passing it up the +// stack. +type LinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is // usually dictated by the backing physical network; when such a // physical network doesn't exist, the limit is generally 64k, which // includes the maximum size of an IP packet. MTU() uint32 + // Capabilities returns the set of capabilities supported by the + // endpoint. + Capabilities() LinkEndpointCapabilities + // MaxHeaderLength returns the maximum size the data link (and // lower level layers combined) headers can have. Higher levels use this // information to reserve space in the front of the packets they're @@ -678,7 +686,7 @@ type NetworkLinkEndpoint interface { MaxHeaderLength() uint16 // LinkAddress returns the link address (typically a MAC) of the - // endpoint. + // link endpoint. LinkAddress() tcpip.LinkAddress // WritePacket writes a packet with the given protocol through the @@ -698,19 +706,6 @@ type NetworkLinkEndpoint interface { // offload is enabled. If it will be used for something else, it may // require to change syscall filters. WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) -} - -// LinkEndpoint is the interface implemented by data link layer protocols (e.g., -// ethernet, loopback, raw) and used by network layer protocols to send packets -// out through the implementer's data link endpoint. When a link header exists, -// it sets each PacketBuffer's LinkHeader field before passing it up the -// stack. -type LinkEndpoint interface { - NetworkLinkEndpoint - - // Capabilities returns the set of capabilities supported by the - // endpoint. - Capabilities() LinkEndpointCapabilities // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. It takes ownership of vv. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 710a02ac3..cc39c9a6a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -72,20 +72,21 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop |= PacketLoop } + linkEP := nic.LinkEndpoint() r := Route{ NetProto: netProto, LocalAddress: localAddr, - LocalLinkAddress: nic.LinkEndpoint.LinkAddress(), + LocalLinkAddress: linkEP.LinkAddress(), RemoteAddress: remoteAddr, addressEndpoint: addressEndpoint, nic: nic, Loop: loop, } - if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok { + if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes - r.linkCache = r.nic.stack + r.linkCache = nic.stack } } @@ -115,7 +116,7 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot // Capabilities returns the link-layer capabilities of the route. func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.nic.LinkEndpoint.Capabilities() + return r.nic.LinkEndpoint().Capabilities() } // GSOMaxSize returns the maximum GSO packet size. @@ -126,6 +127,12 @@ func (r *Route) GSOMaxSize() uint32 { return 0 } +// ResolveWith immediately resolves a route with the specified remote link +// address. +func (r *Route) ResolveWith(addr tcpip.LinkAddress) { + r.RemoteLinkAddress = addr +} + // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). @@ -201,7 +208,16 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Size() + + if err := r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt); err != nil { + return err + } + + r.nic.stats.Tx.Packets.Increment() + r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + return nil } // WritePackets writes a list of n packets through the given route and returns @@ -211,7 +227,15 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead return 0, tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) + n, err := r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) + r.nic.stats.Tx.Packets.IncrementBy(uint64(n)) + writtenBytes := 0 + for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { + writtenBytes += pb.Size() + } + + r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + return n, err } // WriteHeaderIncludedPacket writes a packet already containing a network @@ -221,7 +245,15 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) + // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Data.Size() + + if err := r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt); err != nil { + return err + } + r.nic.stats.Tx.Packets.Increment() + r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + return nil } // DefaultTTL returns the default TTL of the underlying network endpoint. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3a07577c8..0bf20c0e1 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -436,9 +436,9 @@ type Stack struct { // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID - // linkResQueue holds packets that are waiting for link resolution to - // complete. - linkResQueue packetsPendingLinkResolution + // forwarder holds the packets that wait for their link-address resolutions + // to complete, and forwards them when each resolution is done. + forwarder *forwardQueue // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. @@ -550,8 +550,8 @@ type TransportEndpointInfo struct { // incompatible with the receiver. // // Preconditon: the parent endpoint mu must be held while calling this method. -func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := t.NetProto +func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: netProto = header.IPv4ProtocolNumber @@ -565,7 +565,7 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl } } - switch len(t.ID.LocalAddress) { + switch len(e.ID.LocalAddress) { case header.IPv4AddressSize: if len(addr.Addr) == header.IPv6AddressSize { return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState @@ -577,8 +577,8 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl } switch { - case netProto == t.NetProto: - case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber: + case netProto == e.NetProto: + case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber: if v6only { return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute } @@ -640,6 +640,7 @@ func New(opts Options) *Stack { useNeighborCache: opts.UseNeighborCache, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, + forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, @@ -652,7 +653,6 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, } - s.linkResQueue.init() // Add specified network protocols. for _, netProtoFactory := range opts.NetworkProtocols { @@ -928,16 +928,16 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } -// GetLinkEndpointByName gets the link endpoint specified by name. -func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint { +// GetNICByName gets the NIC specified by name. +func (s *Stack) GetNICByName(name string) (*NIC, bool) { s.mu.RLock() defer s.mu.RUnlock() for _, nic := range s.nics { if nic.Name() == name { - return nic.LinkEndpoint + return nic, true } } - return nil + return nil, false } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -1062,13 +1062,13 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { } nics[id] = NICInfo{ Name: nic.name, - LinkAddress: nic.LinkEndpoint.LinkAddress(), + LinkAddress: nic.linkEP.LinkAddress(), ProtocolAddresses: nic.primaryAddresses(), Flags: flags, - MTU: nic.LinkEndpoint.MTU(), + MTU: nic.linkEP.MTU(), Stats: nic.stats, Context: nic.context, - ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), + ARPHardwareType: nic.linkEP.ARPHardwareType(), } } return nics @@ -1323,7 +1323,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) } // Neighbors returns all IP to MAC address associations. @@ -1539,7 +1539,7 @@ func (s *Stack) Wait() { s.mu.RLock() defer s.mu.RUnlock() for _, n := range s.nics { - n.LinkEndpoint.Wait() + n.linkEP.Wait() } } @@ -1627,7 +1627,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t // Add our own fake ethernet header. ethFields := header.EthernetFields{ - SrcAddr: nic.LinkEndpoint.LinkAddress(), + SrcAddr: nic.linkEP.LinkAddress(), DstAddr: dst, Type: netProto, } @@ -1636,7 +1636,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t vv := buffer.View(fakeHeader).ToVectorisedView() vv.Append(payload) - if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil { + if err := nic.linkEP.WriteRawPacket(vv); err != nil { return err } @@ -1653,7 +1653,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) return tcpip.ErrUnknownDevice } - if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil { + if err := nic.linkEP.WriteRawPacket(payload); err != nil { return err } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 38994cca1..aa20f750b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "math" + "net" "sort" "testing" "time" @@ -34,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -75,9 +77,10 @@ type fakeNetworkEndpoint struct { enabled bool } - nic stack.NetworkInterface + nicID tcpip.NICID proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher + ep stack.LinkEndpoint } func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { @@ -100,7 +103,7 @@ func (f *fakeNetworkEndpoint) Disable() { } func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.nic.MTU() - uint32(f.MaxHeaderLength()) + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (*fakeNetworkEndpoint) DefaultTTL() uint8 { @@ -132,7 +135,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.nic.MaxHeaderLength() + fakeNetHeaderLen + return f.ep.MaxHeaderLength() + fakeNetHeaderLen } func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -161,7 +164,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params return nil } - return f.nic.WritePacket(r, gso, fakeNetNumber, pkt) + return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -213,9 +216,10 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &fakeNetworkEndpoint{ - nic: nic, + nicID: nic.ID(), proto: f, dispatcher: dispatcher, + ep: nic.LinkEndpoint(), } e.AddressableEndpointState.Init(e) return e @@ -2102,7 +2106,7 @@ func TestNICStats(t *testing.T) { t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) } - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { + if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -3498,6 +3502,52 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } } +func TestResolveWith(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + }) + ep := channel.New(0, defaultMTU, "") + ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + addr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + }, + } + if err := s.AddProtocolAddress(nicID, addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) + + remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4()) + r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) + } + defer r.Release() + + // Should initially require resolution. + if !r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = false, want = true") + } + + // Manually resolving the route should no longer require resolution. + r.ResolveWith("\x01") + if r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = true, want = false") + } +} + // TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its // associated address is removed should not cause a panic. func TestRouteReleaseAfterAddrRemoval(t *testing.T) { |