diff options
-rw-r--r-- | pkg/sentry/devices/tundev/tundev.go | 14 | ||||
-rw-r--r-- | pkg/sentry/fs/dev/net_tun.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/header/eth.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/link/tun/device.go | 42 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 51 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 36 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarder.go (renamed from pkg/tcpip/stack/pending_packets.go) | 60 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 125 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 48 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 42 |
16 files changed, 269 insertions, 347 deletions
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index 655ea549b..0b701a289 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -16,8 +16,6 @@ package tundev import ( - "fmt" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -28,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -87,16 +84,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg return 0, err } flags := usermem.ByteOrder.Uint16(req.Data[:]) - created, err := fd.device.SetIff(stack.Stack, req.Name(), flags) - if err == nil && created { - // Always start with an ARP address for interfaces so they can handle ARP - // packets. - nicID := fd.device.NICID() - if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID)) - } - } - return 0, err + return 0, fd.device.SetIff(stack.Stack, req.Name(), flags) case linux.TUNGETIFF: var req linux.IFReq diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go index 19ffdec47..5f8c9b5a2 100644 --- a/pkg/sentry/fs/dev/net_tun.go +++ b/pkg/sentry/fs/dev/net_tun.go @@ -15,8 +15,6 @@ package dev import ( - "fmt" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -27,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -63,7 +60,7 @@ func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMod } // GetFile implements fs.InodeOperations.GetFile. -func (*netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { +func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil } @@ -83,12 +80,12 @@ type netTunFileOperations struct { var _ fs.FileOperations = (*netTunFileOperations)(nil) // Release implements fs.FileOperations.Release. -func (n *netTunFileOperations) Release(ctx context.Context) { - n.device.Release(ctx) +func (fops *netTunFileOperations) Release(ctx context.Context) { + fops.device.Release(ctx) } // Ioctl implements fs.FileOperations.Ioctl. -func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { request := args[1].Uint() data := args[2].Pointer() @@ -112,25 +109,16 @@ func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io user return 0, err } flags := usermem.ByteOrder.Uint16(req.Data[:]) - created, err := n.device.SetIff(stack.Stack, req.Name(), flags) - if err == nil && created { - // Always start with an ARP address for interfaces so they can handle ARP - // packets. - nicID := n.device.NICID() - if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID)) - } - } - return 0, err + return 0, fops.device.SetIff(stack.Stack, req.Name(), flags) case linux.TUNGETIFF: var req linux.IFReq - copy(req.IFName[:], n.device.Name()) + copy(req.IFName[:], fops.device.Name()) // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c. - flags := n.device.Flags() | linux.IFF_NOFILTER + flags := fops.device.Flags() | linux.IFF_NOFILTER usermem.ByteOrder.PutUint16(req.Data[:], flags) _, err := req.CopyOut(t, data) @@ -142,41 +130,41 @@ func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io user } // Write implements fs.FileOperations.Write. -func (n *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { +func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { data := make([]byte, src.NumBytes()) if _, err := src.CopyIn(ctx, data); err != nil { return 0, err } - return n.device.Write(data) + return fops.device.Write(data) } // Read implements fs.FileOperations.Read. -func (n *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { - data, err := n.device.Read() +func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + data, err := fops.device.Read() if err != nil { return 0, err } - bytesCopied, err := dst.CopyOut(ctx, data) - if bytesCopied > 0 && bytesCopied < len(data) { + n, err := dst.CopyOut(ctx, data) + if n > 0 && n < len(data) { // Not an error for partial copying. Packet truncated. err = nil } - return int64(bytesCopied), err + return int64(n), err } // Readiness implements watier.Waitable.Readiness. -func (n *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { - return n.device.Readiness(mask) +func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + return fops.device.Readiness(mask) } // EventRegister implements watier.Waitable.EventRegister. -func (n *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - n.device.EventRegister(e, mask) +func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fops.device.EventRegister(e, mask) } // EventUnregister implements watier.Waitable.EventUnregister. -func (n *netTunFileOperations) EventUnregister(e *waiter.Entry) { - n.device.EventUnregister(e) +func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) { + fops.device.EventUnregister(e) } // isNetTunSupported returns whether /dev/net/tun device is supported for s. diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index 95ade0e5c..eaface8cb 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -117,31 +117,25 @@ func (b Ethernet) Encode(e *EthernetFields) { copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr) } -// IsMulticastEthernetAddress returns true if the address is a multicast -// ethernet address. -func IsMulticastEthernetAddress(addr tcpip.LinkAddress) bool { - if len(addr) != EthernetAddressSize { - return false - } - - return addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 -} - -// IsValidUnicastEthernetAddress returns true if the address is a unicast +// IsValidUnicastEthernetAddress returns true if addr is a valid unicast // ethernet address. func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool { + // Must be of the right length. if len(addr) != EthernetAddressSize { return false } + // Must not be unspecified. if addr == unspecifiedEthernetAddress { return false } + // Must not be a multicast. if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 { return false } + // addr is a valid unicast ethernet address. return true } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index f94491026..b6ddbe81e 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -76,29 +76,13 @@ func (d *Device) Release(ctx context.Context) { } } -// NICID returns the NIC ID of the device. -// -// Must only be called after the device has been attached to an endpoint. -func (d *Device) NICID() tcpip.NICID { - d.mu.RLock() - defer d.mu.RUnlock() - - if d.endpoint == nil { - panic("called NICID on a device that has not been attached") - } - - return d.endpoint.nicID -} - // SetIff services TUNSETIFF ioctl(2) request. -// -// Returns true if a new NIC was created; false if an existing one was attached. -func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) { +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { d.mu.Lock() defer d.mu.Unlock() if d.endpoint != nil { - return false, syserror.EINVAL + return syserror.EINVAL } // Input validations. @@ -106,7 +90,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) isTap := flags&linux.IFF_TAP != 0 supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { - return false, syserror.EINVAL + return syserror.EINVAL } prefix := "tun" @@ -119,32 +103,32 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) linkCaps |= stack.CapabilityResolutionRequired } - endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps) + endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps) if err != nil { - return false, syserror.EINVAL + return syserror.EINVAL } d.endpoint = endpoint d.notifyHandle = d.endpoint.AddNotify(d) d.flags = flags - return created, nil + return nil } -func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) { +func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) { for { // 1. Try to attach to an existing NIC. if name != "" { - if linkEP := s.GetLinkEndpointByName(name); linkEP != nil { - endpoint, ok := linkEP.(*tunEndpoint) + if nic, found := s.GetNICByName(name); found { + endpoint, ok := nic.LinkEndpoint().(*tunEndpoint) if !ok { // Not a NIC created by tun device. - return nil, false, syserror.EOPNOTSUPP + return nil, syserror.EOPNOTSUPP } if !endpoint.TryIncRef() { // Race detected: NIC got deleted in between. continue } - return endpoint, false, nil + return endpoint, nil } } @@ -167,12 +151,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE }) switch err { case nil: - return endpoint, true, nil + return endpoint, nil case tcpip.ErrDuplicateNICID: // Race detected: A NIC has been created in between. continue default: - return nil, false, syserror.EINVAL + return nil, syserror.EINVAL } } } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7df77c66e..b47a7be51 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -49,6 +49,7 @@ type endpoint struct { enabled uint32 nic stack.NetworkInterface + linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler } @@ -91,12 +92,12 @@ func (e *endpoint) DefaultTTL() uint8 { } func (e *endpoint) MTU() uint32 { - lmtu := e.nic.MTU() + lmtu := e.linkEP.MTU() return lmtu - uint32(e.MaxHeaderLength()) } func (e *endpoint) MaxHeaderLength() uint16 { - return e.nic.MaxHeaderLength() + header.ARPSize + return e.linkEP.MaxHeaderLength() + header.ARPSize } func (e *endpoint) Close() { @@ -153,25 +154,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } - // As per RFC 826, under Packet Reception: - // Swap hardware and protocol fields, putting the local hardware and - // protocol addresses in the sender fields. - // - // Send the packet to the (new) target hardware address on the same - // hardware on which the request was received. - origSender := h.HardwareAddressSender() - r.RemoteLinkAddress = tcpip.LinkAddress(origSender) - respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize, }) - packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) + packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(packet.HardwareAddressTarget(), origSender) + copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt) + _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -214,6 +207,7 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L e := &endpoint{ protocol: p, nic: nic, + linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, } @@ -229,7 +223,6 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ - NetProto: ProtocolNumber, RemoteLinkAddress: remoteLinkAddr, } if len(r.RemoteLinkAddress) == 0 { diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 3407755ed..eab9a530c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -102,6 +102,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + remoteLinkAddr := r.RemoteLinkAddress + // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP // source address MUST be one of its own IP addresses (but not a broadcast // or multicast address). @@ -117,6 +119,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { } defer r.Release() + // Use the remote link address from the incoming packet. + r.ResolveWith(remoteLinkAddr) + // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the // header information, we may have to change this code to handle the // ICMP header no longer being in the data buffer. @@ -239,7 +244,13 @@ func (*icmpReasonProtoUnreachable) isICMPReason() {} // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. -func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + sent := r.Stats().ICMP.V4PacketsSent + if !r.Stack().AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + // We check we are responding only when we are allowed to. // See RFC 1812 section 4.3.2.7 (shown below). // @@ -268,25 +279,6 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac return nil } - // Even if we were able to receive a packet from some remote, we may not have - // a route to it - the remote may be blocked via routing rules. We must always - // consult our routing table and find a route to the remote before sending any - // packet. - route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err - } - defer route.Release() - // From this point on, the incoming route should no longer be used; route - // must be used to send the ICMP error. - r = nil - - sent := p.stack.Stats().ICMP.V4PacketsSent - if !p.stack.AllowICMPMessage() { - sent.RateLimited.Increment() - return nil - } - networkHeader := pkt.NetworkHeader().View() transportHeader := pkt.TransportHeader().View() @@ -337,11 +329,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // least 8 bytes of the payload must be included. Today linux and other // systems implement the RFC 1812 definition and not the original // requirement. We treat 8 bytes as the minimum but will try send more. - mtu := int(route.MTU()) + mtu := int(r.MTU()) if mtu > header.IPv4MinimumProcessableDatagramSize { mtu = header.IPv4MinimumProcessableDatagramSize } - headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize + headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize { @@ -386,11 +378,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) counter := sent.DstUnreachable - if err := route.WritePacket( + if err := r.WritePacket( nil, /* gso */ stack.NetworkHeaderParams{ Protocol: header.ICMPv4ProtocolNumber, - TTL: route.DefaultTTL(), + TTL: r.DefaultTTL(), TOS: stack.DefaultTOS, }, icmpPkt, diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 115fb1ab0..99274dd45 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -66,6 +66,7 @@ var _ stack.NetworkEndpoint = (*endpoint)(nil) type endpoint struct { nic stack.NetworkInterface + linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher protocol *protocol @@ -86,6 +87,7 @@ type endpoint struct { func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, + linkEP: nic.LinkEndpoint(), dispatcher: dispatcher, protocol: p, } @@ -176,18 +178,18 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.nic.MTU()) + return calculateMTU(e.linkEP.MTU()) } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { - return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize + return e.linkEP.MaxHeaderLength() + header.IPv4MaximumHeaderSize } // GSOMaxSize returns the maximum GSO packet size. func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.nic.(stack.GSOEndpoint); ok { + if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -208,7 +210,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint for { fragPkt, more := buildNextFragment(&pf, networkHeader) - if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1)) return err } @@ -281,10 +283,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, e.nic.MTU(), pkt) + if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, e.linkEP.MTU(), pkt) } - if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } @@ -314,7 +316,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. - n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) @@ -341,7 +343,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } } - if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) // Dropped packets aren't errors, so include them in @@ -402,7 +404,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return nil } - if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } @@ -510,13 +512,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // 3 (Port Unreachable), when the designated transport protocol // (e.g., UDP) is unable to demultiplex the datagram but has no // protocol mechanism to inform the sender. - _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC: 1122 Section 3.2.2.1 // A host SHOULD generate Destination Unreachable messages with code: // 2 (Protocol Unreachable), when the designated transport protocol // is not supported - _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt) + _ = returnError(r, &icmpReasonProtoUnreachable{}, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 7be35c78b..37c169a5d 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -440,6 +440,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } + remoteLinkAddr := r.RemoteLinkAddress + // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. localAddr := r.LocalAddress @@ -454,6 +456,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } defer r.Release() + // Use the link address from the source of the original packet. + r.ResolveWith(remoteLinkAddr) + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, Data: pkt.Data, @@ -737,7 +742,14 @@ func (*icmpReasonPortUnreachable) isICMPReason() {} // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. -func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + stats := r.Stats().ICMP + sent := stats.V6PacketsSent + if !r.Stack().AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + // Only send ICMP error if the address is not a multicast v6 // address and the source is not the unspecified address. // @@ -768,26 +780,6 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac return nil } - // Even if we were able to receive a packet from some remote, we may not have - // a route to it - the remote may be blocked via routing rules. We must always - // consult our routing table and find a route to the remote before sending any - // packet. - route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err - } - defer route.Release() - // From this point on, the incoming route should no longer be used; route - // must be used to send the ICMP error. - r = nil - - stats := p.stack.Stats().ICMP - sent := stats.V6PacketsSent - if !p.stack.AllowICMPMessage() { - sent.RateLimited.Increment() - return nil - } - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { @@ -814,11 +806,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // packet that caused the error) as possible without making // the error message packet exceed the minimum IPv6 MTU // [IPv6]. - mtu := int(route.MTU()) + mtu := int(r.MTU()) if mtu > header.IPv6MinimumMTU { mtu = header.IPv6MinimumMTU } - headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize + headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize available := int(mtu) - headerLen if available < header.IPv6MinimumSize { return nil @@ -851,16 +843,9 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data)) - if err := route.WritePacket( - nil, /* gso */ - stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: route.DefaultTTL(), - TOS: stack.DefaultTOS, - }, - newPkt, - ); err != nil { + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data)) + err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt) + if err != nil { sent.Dropped.Increment() return err } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 4f360df2c..826342c4f 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -66,6 +66,7 @@ var _ NDPEndpoint = (*endpoint)(nil) type endpoint struct { nic stack.NetworkInterface + linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler dispatcher stack.TransportDispatcher @@ -363,18 +364,18 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.nic.MTU()) + return calculateMTU(e.linkEP.MTU()) } // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { - return e.nic.MaxHeaderLength() + header.IPv6MinimumSize + return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize } // GSOMaxSize returns the maximum GSO packet size. func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.nic.(stack.GSOEndpoint); ok { + if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -395,7 +396,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s } func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) + return pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) } // handleFragments fragments pkt and calls the handler function on each @@ -476,19 +477,19 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + sent, remain, err := e.handleFragments(r, gso, e.linkEP.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to // WritePackets(). It'll be faster but cost more memory. - return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) + return e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt) }) r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } - if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } @@ -510,7 +511,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe e.addIPHeader(r, pb, params) if e.packetMustBeFragmented(pb, gso) { current := pb - _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + _, _, err := e.handleFragments(r, gso, e.linkEP.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(current, fragPkt) current = current.Next() @@ -535,7 +536,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. - n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) @@ -562,7 +563,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } } - if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) // Dropped packets aren't errors, so include them in @@ -642,7 +643,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. if previousHeaderStart != 0 { - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = returnError(r, &icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: previousHeaderStart, }, pkt) @@ -681,7 +682,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = returnError(r, &icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -706,7 +707,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // header, so we just make sure Segments Left is zero before processing // the next extension header. if extHdr.SegmentsLeft() != 0 { - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = returnError(r, &icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: it.ParseOffset(), }, pkt) @@ -858,7 +859,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = returnError(r, &icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -895,7 +896,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // message with Code 4 in response to a packet for which the // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. - _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC 8200 section 4. (page 7): // Extension headers are numbered from IANA IP Protocol Numbers @@ -916,7 +917,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // // Which when taken together indicate that an unknown protocol should // be treated as an unrecognized next header value. - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = returnError(r, &icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: it.ParseOffset(), }, pkt) @@ -926,7 +927,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } default: - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = returnError(r, &icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: it.ParseOffset(), }, pkt) @@ -1301,6 +1302,7 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, + linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, dispatcher: dispatcher, diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 40da011f8..48a4c65e3 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1289,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt // // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by // LinkEndpoint.LinkAddress) before reaching this point. - linkAddr := ndp.ep.nic.LinkAddress() + linkAddr := ndp.ep.linkEP.LinkAddress() if !header.IsValidUnicastEthernetAddress(linkAddr) { return false } diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/forwarder.go index f838eda8d..3eff141e6 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/forwarder.go @@ -29,60 +29,60 @@ const ( ) type pendingPacket struct { + nic *NIC route *Route proto tcpip.NetworkProtocolNumber pkt *PacketBuffer } -// packetsPendingLinkResolution is a queue of packets pending link resolution. -// -// Once link resolution completes successfully, the packets will be written. -type packetsPendingLinkResolution struct { +type forwardQueue struct { sync.Mutex // The packets to send once the resolver completes. - packets map[<-chan struct{}][]pendingPacket + packets map[<-chan struct{}][]*pendingPacket // FIFO of channels used to cancel the oldest goroutine waiting for // link-address resolution. cancelChans []chan struct{} } -func (f *packetsPendingLinkResolution) init() { - f.Lock() - defer f.Unlock() - f.packets = make(map[<-chan struct{}][]pendingPacket) +func newForwardQueue() *forwardQueue { + return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} } -func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - f.Lock() - defer f.Unlock() +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + shouldWait := false + f.Lock() packets, ok := f.packets[ch] - if len(packets) == maxPendingPacketsPerResolution { + if !ok { + shouldWait = true + } + for len(packets) == maxPendingPacketsPerResolution { p := packets[0] - packets[0] = pendingPacket{} packets = packets[1:] - p.route.Stats().IP.OutgoingPacketErrors.Increment() + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() p.route.Release() } - if l := len(packets); l >= maxPendingPacketsPerResolution { panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) } - - f.packets[ch] = append(packets, pendingPacket{ + f.packets[ch] = append(packets, &pendingPacket{ + nic: n, route: r, - proto: proto, + proto: protocol, pkt: pkt, }) + f.Unlock() - if ok { + if !shouldWait { return } // Wait for the link-address resolution to complete. - cancel := f.newCancelChannelLocked() + // Start a goroutine with a forwarding-cancel channel so that we can + // limit the maximum number of goroutines running concurrently. + cancel := f.newCancelChannel() go func() { cancelled := false select { @@ -92,21 +92,17 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro } f.Lock() - packets, ok := f.packets[ch] + packets := f.packets[ch] delete(f.packets, ch) f.Unlock() - if !ok { - panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets")) - } - for _, p := range packets { if cancelled { - p.route.Stats().IP.OutgoingPacketErrors.Increment() + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() } else if _, err := p.route.Resolve(nil); err != nil { - p.route.Stats().IP.OutgoingPacketErrors.Increment() + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() } else { - p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt) + p.nic.forwardPacket(p.route, p.proto, p.pkt) } p.route.Release() } @@ -116,10 +112,12 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro // newCancelChannel creates a channel that can cancel a pending forwarding // activity. The oldest channel is closed if the number of open channels would // exceed maxPendingResolutions. -func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} { +func (f *forwardQueue) newCancelChannel() chan struct{} { + f.Lock() + defer f.Unlock() + if len(f.cancelChans) == maxPendingResolutions { ch := f.cancelChans[0] - f.cancelChans[0] = nil f.cancelChans = f.cancelChans[1:] close(ch) } diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 4d69a4de1..9a72bec79 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -236,7 +236,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil { // There is no need to log the error here; the NUD implementation may // assume a working link. A valid link should be the responsibility of // the NIC/stack.LinkEndpoint. @@ -277,7 +277,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8828cc5fe..6cf54cc89 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -32,11 +32,10 @@ var _ NetworkInterface = (*NIC)(nil) // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { - LinkEndpoint - stack *Stack id tcpip.NICID name string + linkEP LinkEndpoint context NICContext stats NICStats @@ -92,11 +91,10 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // of IPv6 is supported on this endpoint's LinkEndpoint. nic := &NIC{ - LinkEndpoint: ep, - stack: stack, id: id, name: name, + linkEP: ep, context: ctx, stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), @@ -132,7 +130,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } - nic.LinkEndpoint.Attach(nic) + nic.linkEP.Attach(nic) return nic } @@ -222,7 +220,7 @@ func (n *NIC) remove() *tcpip.Error { } // Detach from link endpoint, so no packet comes in. - n.LinkEndpoint.Attach(nil) + n.linkEP.Attach(nil) return nil } @@ -242,64 +240,7 @@ func (n *NIC) isPromiscuousMode() bool { // IsLoopback implements NetworkInterface. func (n *NIC) IsLoopback() bool { - return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0 -} - -// WritePacket implements NetworkLinkEndpoint. -func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { - // As per relevant RFCs, we should queue packets while we wait for link - // resolution to complete. - // - // RFC 1122 section 2.3.2.2 (for IPv4): - // The link layer SHOULD save (rather than discard) at least - // one (the latest) packet of each set of packets destined to - // the same unresolved IP address, and transmit the saved - // packet when the address has been resolved. - // - // RFC 4861 section 5.2 (for IPv6): - // Once the IP address of the next-hop node is known, the sender - // examines the Neighbor Cache for link-layer information about that - // neighbor. If no entry exists, the sender creates one, sets its state - // to INCOMPLETE, initiates Address Resolution, and then queues the data - // packet pending completion of address resolution. - if ch, err := r.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - r := r.Clone() - n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt) - return nil - } - return err - } - - return n.writePacket(r, gso, protocol, pkt) -} - -func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { - // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Size() - - if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil { - return err - } - - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) - return nil -} - -// WritePackets implements NetworkLinkEndpoint. -func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution - // is being peformed like WritePacket. - writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) - n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) - writtenBytes := 0 - for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { - writtenBytes += pb.Size() - } - - n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) - return writtenPackets, err + return n.linkEP.Capabilities()&CapabilityLoopback != 0 } // setSpoofing enables or disables address spoofing. @@ -584,7 +525,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // If no local link layer address is provided, assume it was sent // directly to this NIC. if local == "" { - local = n.LinkEndpoint.LinkAddress() + local = n.linkEP.LinkAddress() } // Are any packet type sockets listening for this network protocol? @@ -664,7 +605,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n := r.nic if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { if n.isValidForOutgoing(addressEndpoint) { - r.LocalLinkAddress = n.LinkEndpoint.LinkAddress() + r.LocalLinkAddress = n.linkEP.LinkAddress() r.RemoteLinkAddress = remote r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. @@ -679,21 +620,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. + // TODO(b/128629022): move this logic to route.WritePacket. // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. - - // pkt may have set its header and may not have enough headroom for - // link-layer header for the other link to prepend. Here we create a new - // packet to forward. - fwdPkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - }) - - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil { + if ch, err := r.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) + // forwarder will release route. + return + } n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() + r.Release() + return } + // The link-address resolution finished immediately. + n.forwardPacket(&r, protocol, pkt) r.Release() return } @@ -717,11 +658,34 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc p.PktType = tcpip.PacketOutgoing // Add the link layer header as outgoing packets are intercepted // before the link layer header is created. - n.LinkEndpoint.AddHeader(local, remote, protocol, p) + n.linkEP.AddHeader(local, remote, protocol, p) ep.HandlePacket(n.id, local, protocol, p) } } +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + + // pkt may have set its header and may not have enough headroom for link-layer + // header for the other link to prepend. Here we create a new packet to + // forward. + fwdPkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()), + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) + + // WritePacket takes ownership of fwdPkt, calculate numBytes first. + numBytes := fwdPkt.Size() + + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return + } + + n.stats.Tx.Packets.Increment() + n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) +} + // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { @@ -832,6 +796,11 @@ func (n *NIC) Name() string { return n.name } +// LinkEndpoint implements NetworkInterface. +func (n *NIC) LinkEndpoint() LinkEndpoint { + return n.linkEP +} + // nudConfigs gets the NUD configurations for n. func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { if n.neigh == nil { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index defb9129b..be9bd8042 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -475,8 +475,6 @@ type NDPEndpoint interface { // NetworkInterface is a network interface. type NetworkInterface interface { - NetworkLinkEndpoint - // ID returns the interface's ID. ID() tcpip.NICID @@ -490,6 +488,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + + // LinkEndpoint returns the link endpoint backing the interface. + LinkEndpoint() LinkEndpoint } // NetworkEndpoint is the interface that needs to be implemented by endpoints @@ -662,15 +663,22 @@ const ( CapabilitySoftwareGSO ) -// NetworkLinkEndpoint is a data-link layer that supports sending network -// layer packets. -type NetworkLinkEndpoint interface { +// LinkEndpoint is the interface implemented by data link layer protocols (e.g., +// ethernet, loopback, raw) and used by network layer protocols to send packets +// out through the implementer's data link endpoint. When a link header exists, +// it sets each PacketBuffer's LinkHeader field before passing it up the +// stack. +type LinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is // usually dictated by the backing physical network; when such a // physical network doesn't exist, the limit is generally 64k, which // includes the maximum size of an IP packet. MTU() uint32 + // Capabilities returns the set of capabilities supported by the + // endpoint. + Capabilities() LinkEndpointCapabilities + // MaxHeaderLength returns the maximum size the data link (and // lower level layers combined) headers can have. Higher levels use this // information to reserve space in the front of the packets they're @@ -678,7 +686,7 @@ type NetworkLinkEndpoint interface { MaxHeaderLength() uint16 // LinkAddress returns the link address (typically a MAC) of the - // endpoint. + // link endpoint. LinkAddress() tcpip.LinkAddress // WritePacket writes a packet with the given protocol through the @@ -698,19 +706,6 @@ type NetworkLinkEndpoint interface { // offload is enabled. If it will be used for something else, it may // require to change syscall filters. WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) -} - -// LinkEndpoint is the interface implemented by data link layer protocols (e.g., -// ethernet, loopback, raw) and used by network layer protocols to send packets -// out through the implementer's data link endpoint. When a link header exists, -// it sets each PacketBuffer's LinkHeader field before passing it up the -// stack. -type LinkEndpoint interface { - NetworkLinkEndpoint - - // Capabilities returns the set of capabilities supported by the - // endpoint. - Capabilities() LinkEndpointCapabilities // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. It takes ownership of vv. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 710a02ac3..cc39c9a6a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -72,20 +72,21 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop |= PacketLoop } + linkEP := nic.LinkEndpoint() r := Route{ NetProto: netProto, LocalAddress: localAddr, - LocalLinkAddress: nic.LinkEndpoint.LinkAddress(), + LocalLinkAddress: linkEP.LinkAddress(), RemoteAddress: remoteAddr, addressEndpoint: addressEndpoint, nic: nic, Loop: loop, } - if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok { + if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes - r.linkCache = r.nic.stack + r.linkCache = nic.stack } } @@ -115,7 +116,7 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot // Capabilities returns the link-layer capabilities of the route. func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.nic.LinkEndpoint.Capabilities() + return r.nic.LinkEndpoint().Capabilities() } // GSOMaxSize returns the maximum GSO packet size. @@ -126,6 +127,12 @@ func (r *Route) GSOMaxSize() uint32 { return 0 } +// ResolveWith immediately resolves a route with the specified remote link +// address. +func (r *Route) ResolveWith(addr tcpip.LinkAddress) { + r.RemoteLinkAddress = addr +} + // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). @@ -201,7 +208,16 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Size() + + if err := r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt); err != nil { + return err + } + + r.nic.stats.Tx.Packets.Increment() + r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + return nil } // WritePackets writes a list of n packets through the given route and returns @@ -211,7 +227,15 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead return 0, tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) + n, err := r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) + r.nic.stats.Tx.Packets.IncrementBy(uint64(n)) + writtenBytes := 0 + for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { + writtenBytes += pb.Size() + } + + r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + return n, err } // WriteHeaderIncludedPacket writes a packet already containing a network @@ -221,7 +245,15 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) + // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Data.Size() + + if err := r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt); err != nil { + return err + } + r.nic.stats.Tx.Packets.Increment() + r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + return nil } // DefaultTTL returns the default TTL of the underlying network endpoint. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3a07577c8..0bf20c0e1 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -436,9 +436,9 @@ type Stack struct { // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID - // linkResQueue holds packets that are waiting for link resolution to - // complete. - linkResQueue packetsPendingLinkResolution + // forwarder holds the packets that wait for their link-address resolutions + // to complete, and forwards them when each resolution is done. + forwarder *forwardQueue // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. @@ -550,8 +550,8 @@ type TransportEndpointInfo struct { // incompatible with the receiver. // // Preconditon: the parent endpoint mu must be held while calling this method. -func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := t.NetProto +func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: netProto = header.IPv4ProtocolNumber @@ -565,7 +565,7 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl } } - switch len(t.ID.LocalAddress) { + switch len(e.ID.LocalAddress) { case header.IPv4AddressSize: if len(addr.Addr) == header.IPv6AddressSize { return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState @@ -577,8 +577,8 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl } switch { - case netProto == t.NetProto: - case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber: + case netProto == e.NetProto: + case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber: if v6only { return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute } @@ -640,6 +640,7 @@ func New(opts Options) *Stack { useNeighborCache: opts.UseNeighborCache, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, + forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, @@ -652,7 +653,6 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, } - s.linkResQueue.init() // Add specified network protocols. for _, netProtoFactory := range opts.NetworkProtocols { @@ -928,16 +928,16 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } -// GetLinkEndpointByName gets the link endpoint specified by name. -func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint { +// GetNICByName gets the NIC specified by name. +func (s *Stack) GetNICByName(name string) (*NIC, bool) { s.mu.RLock() defer s.mu.RUnlock() for _, nic := range s.nics { if nic.Name() == name { - return nic.LinkEndpoint + return nic, true } } - return nil + return nil, false } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -1062,13 +1062,13 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { } nics[id] = NICInfo{ Name: nic.name, - LinkAddress: nic.LinkEndpoint.LinkAddress(), + LinkAddress: nic.linkEP.LinkAddress(), ProtocolAddresses: nic.primaryAddresses(), Flags: flags, - MTU: nic.LinkEndpoint.MTU(), + MTU: nic.linkEP.MTU(), Stats: nic.stats, Context: nic.context, - ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), + ARPHardwareType: nic.linkEP.ARPHardwareType(), } } return nics @@ -1323,7 +1323,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) } // Neighbors returns all IP to MAC address associations. @@ -1539,7 +1539,7 @@ func (s *Stack) Wait() { s.mu.RLock() defer s.mu.RUnlock() for _, n := range s.nics { - n.LinkEndpoint.Wait() + n.linkEP.Wait() } } @@ -1627,7 +1627,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t // Add our own fake ethernet header. ethFields := header.EthernetFields{ - SrcAddr: nic.LinkEndpoint.LinkAddress(), + SrcAddr: nic.linkEP.LinkAddress(), DstAddr: dst, Type: netProto, } @@ -1636,7 +1636,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t vv := buffer.View(fakeHeader).ToVectorisedView() vv.Append(payload) - if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil { + if err := nic.linkEP.WriteRawPacket(vv); err != nil { return err } @@ -1653,7 +1653,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) return tcpip.ErrUnknownDevice } - if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil { + if err := nic.linkEP.WriteRawPacket(payload); err != nil { return err } |