diff options
Diffstat (limited to 'pkg/tcpip/network')
24 files changed, 13324 insertions, 1930 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 6a4839fb8..c118a2929 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -9,13 +9,17 @@ go_test( "ip_test.go", ], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", ], diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index eddf7b725..8a6bcfc2c 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/stack", ], ) @@ -28,5 +29,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 9d0797af7..a79379abb 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -15,20 +15,16 @@ // Package arp implements the ARP network protocol. It is used to resolve // IPv4 addresses into link-local MAC addresses, and advertises IPv4 // addresses of its stack with the local network. -// -// To use it in the networking stack, pass arp.NewProtocol() as one of the -// network protocols when calling stack.New. Then add an "arp" address to every -// NIC on the stack that should respond to ARP requests. That is: -// -// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { -// // handle err -// } package arp import ( + "fmt" + "sync/atomic" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -40,53 +36,81 @@ const ( ProtocolAddress = tcpip.Address("arp") ) -// endpoint implements stack.NetworkEndpoint. +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) + type endpoint struct { - protocol *protocol - nicID tcpip.NICID - linkEP stack.LinkEndpoint + stack.AddressableEndpointState + + protocol *protocol + + // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + + nic stack.NetworkInterface linkAddrCache stack.LinkAddressCache + nud stack.NUDHandler } -// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. -func (e *endpoint) DefaultTTL() uint8 { - return 0 +func (e *endpoint) Enable() *tcpip.Error { + if !e.nic.Enabled() { + return tcpip.ErrNotPermitted + } + + e.setEnabled(true) + return nil } -func (e *endpoint) MTU() uint32 { - lmtu := e.linkEP.MTU() - return lmtu - uint32(e.MaxHeaderLength()) +func (e *endpoint) Enabled() bool { + return e.nic.Enabled() && e.isEnabled() } -func (e *endpoint) NICID() tcpip.NICID { - return e.nicID +// isEnabled returns true if the endpoint is enabled, regardless of the +// enabled status of the NIC. +func (e *endpoint) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) == 1 } -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() +// setEnabled sets the enabled status for the endpoint. +func (e *endpoint) setEnabled(v bool) { + if v { + atomic.StoreUint32(&e.enabled, 1) + } else { + atomic.StoreUint32(&e.enabled, 0) + } } -func (e *endpoint) ID() *stack.NetworkEndpointID { - return &stack.NetworkEndpointID{ProtocolAddress} +func (e *endpoint) Disable() { + e.setEnabled(false) } -func (e *endpoint) PrefixLen() int { +// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. +func (e *endpoint) DefaultTTL() uint8 { return 0 } +func (e *endpoint) MTU() uint32 { + lmtu := e.nic.MTU() + return lmtu - uint32(e.MaxHeaderLength()) +} + func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.ARPSize + return e.nic.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() {} +func (e *endpoint) Close() { + e.AddressableEndpointState.Cleanup() +} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return e.protocol.Number() + return ProtocolNumber } // WritePackets implements stack.NetworkEndpoint.WritePackets. @@ -94,16 +118,16 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + if !e.isEnabled() { return } - h := header.ARP(v) + + h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { return } @@ -111,30 +135,80 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { switch h.Op() { case header.ARPRequest: localAddr := tcpip.Address(h.ProtocolAddressTarget()) - if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { - return // we have no useful answer, ignore the request + + if e.nud == nil { + if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + return // we have no useful answer, ignore the request + } + + addr := tcpip.Address(h.ProtocolAddressSender()) + linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) + } else { + if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + return // we have no useful answer, ignore the request + } + + remoteAddr := tcpip.Address(h.ProtocolAddressSender()) + remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } - hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - packet := header.ARP(hdr.Prepend(header.ARPSize)) + + respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, + }) + packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) - copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) - copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ - Header: hdr, - }) - fallthrough // also fill the cache from requests + // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a + // link address. + _ = copy(packet.HardwareAddressSender(), e.nic.LinkAddress()) + if n := copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + origSender := h.HardwareAddressSender() + if n := copy(packet.HardwareAddressTarget(), origSender); n != header.EthernetAddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.EthernetAddressSize)) + } + if n := copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + + // As per RFC 826, under Packet Reception: + // Swap hardware and protocol fields, putting the local hardware and + // protocol addresses in the sender fields. + // + // Send the packet to the (new) target hardware address on the same + // hardware on which the request was received. + _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt) + case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) + + if e.nud == nil { + e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) + return + } + + // The solicited, override, and isRouter flags are not available for ARP; + // they are only available for IPv6 Neighbor Advertisements. + e.nud.HandleConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{ + // Solicited and unsolicited (also referred to as gratuitous) ARP Replies + // are handled equivalently to a solicited Neighbor Advertisement. + Solicited: true, + // If a different link address is received than the one cached, the entry + // should always go to Stale. + Override: false, + // ARP does not distinguish between router and non-router hosts. + IsRouter: false, + }) } } // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { + stack *stack.Stack } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -146,16 +220,15 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { - if addrWithPrefix.Address != ProtocolAddress { - return nil, tcpip.ErrBadLocalAddress - } - return &endpoint{ +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &endpoint{ protocol: p, - nicID: nicID, - linkEP: sender, + nic: nic, linkAddrCache: linkAddrCache, - }, nil + nud: nud, + } + e.AddressableEndpointState.Init(e) + return e } // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. @@ -164,28 +237,50 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { - r := &stack.Route{ - RemoteLinkAddress: broadcastMAC, +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + if len(remoteLinkAddr) == 0 { + remoteLinkAddr = header.EthernetBroadcastAddress } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) - h := header.ARP(hdr.Prepend(header.ARPSize)) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), linkEP.LinkAddress()) - copy(h.ProtocolAddressSender(), localAddr) - copy(h.ProtocolAddressTarget(), addr) + nicID := nic.ID() + if len(localAddr) == 0 { + addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + if err != nil { + return err + } + + if len(addr.Address) == 0 { + return tcpip.ErrNetworkUnreachable + } + + localAddr = addr.Address + } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + return tcpip.ErrBadLocalAddress + } - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ - Header: hdr, + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize, }) + h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) + pkt.NetworkProtocolNumber = ProtocolNumber + h.SetIPv4OverEthernet() + h.SetOp(header.ARPRequest) + // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a + // link address. + _ = copy(h.HardwareAddressSender(), nic.LinkAddress()) + if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { - return broadcastMAC, true + return header.EthernetBroadcastAddress, true } if header.IsV4MulticastAddress(addr) { return header.EthernetAddressFromMulticastIPv4Address(addr), true @@ -194,12 +289,12 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } // SetOption implements stack.NetworkProtocol.SetOption. -func (*protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.NetworkProtocol.Option. -func (*protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } @@ -209,9 +304,16 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) +// Parse implements stack.NetworkProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + return 0, false, parse.ARP(pkt) +} // NewProtocol returns an ARP network protocol. -func NewProtocol() stack.NetworkProtocol { - return &protocol{} +// +// Note, to make sure that the ARP endpoint receives ARP packets, the "arp" +// address must be added to every NIC that should respond to ARP requests. See +// ProtocolAddress for more details. +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return &protocol{stack: s} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 1646d9cde..bf1292bb8 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -16,10 +16,13 @@ package arp_test import ( "context" + "fmt" "strconv" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -32,54 +35,184 @@ import ( ) const ( + nicID = 1 + + stackAddr = tcpip.Address("\x0a\x00\x00\x01") stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") - stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") - stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") + + remoteAddr = tcpip.Address("\x0a\x00\x00\x02") + remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") + + unknownAddr = tcpip.Address("\x0a\x00\x00\x03") + + defaultChannelSize = 1 + defaultMTU = 65536 + + // eventChanSize defines the size of event channels used by the neighbor + // cache's event dispatcher. The size chosen here needs to be sufficient to + // queue all the events received during tests before consumption. + // If eventChanSize is too small, the tests may deadlock. + eventChanSize = 32 +) + +type eventType uint8 + +const ( + entryAdded eventType = iota + entryChanged + entryRemoved ) +func (t eventType) String() string { + switch t { + case entryAdded: + return "add" + case entryChanged: + return "change" + case entryRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +type eventInfo struct { + eventType eventType + nicID tcpip.NICID + entry stack.NeighborEntry +} + +func (e eventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) +} + +// arpDispatcher implements NUDDispatcher to validate the dispatching of +// events upon certain NUD state machine events. +type arpDispatcher struct { + // C is where events are queued + C chan eventInfo +} + +var _ stack.NUDDispatcher = (*arpDispatcher)(nil) + +func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryAdded, + nicID: nicID, + entry: entry, + } + d.C <- e +} + +func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryChanged, + nicID: nicID, + entry: entry, + } + d.C <- e +} + +func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryRemoved, + nicID: nicID, + entry: entry, + } + d.C <- e +} + +func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { + select { + case got := <-d.C: + if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { + return fmt.Errorf("got invalid event (-got +want):\n%s", diff) + } + case <-ctx.Done(): + return fmt.Errorf("%s for %s", ctx.Err(), want) + } + return nil +} + +func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return d.waitForEvent(ctx, want) +} + +func (d *arpDispatcher) nextEvent() (eventInfo, bool) { + select { + case event := <-d.C: + return event, true + default: + return eventInfo{}, false + } +} + type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack + s *stack.Stack + linkEP *channel.Endpoint + nudDisp *arpDispatcher } -func newTestContext(t *testing.T) *testContext { +func newTestContext(t *testing.T, useNeighborCache bool) *testContext { + c := stack.DefaultNUDConfigurations() + // Transition from Reachable to Stale almost immediately to test if receiving + // probes refreshes positive reachability. + c.BaseReachableTime = time.Microsecond + + d := arpDispatcher{ + // Create an event channel large enough so the neighbor cache doesn't block + // while dispatching events. Blocking could interfere with the timing of + // NUD transitions. + C: make(chan eventInfo, eventChanSize), + } + s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + NUDConfigs: c, + NUDDisp: &d, + UseNeighborCache: useNeighborCache, }) - const defaultMTU = 65536 - ep := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired + wep := stack.LinkEndpoint(ep) if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNIC(1, wep); err != nil { + if err := s.CreateNIC(nicID, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress for ipv4 failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) + if !useNeighborCache { + // The remote address needs to be assigned to the NIC so we can receive and + // verify outgoing ARP packets. The neighbor cache isn't concerned with + // this; the tests that use linkAddrCache expect the ARP responses to be + // received by the same NIC. + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, remoteAddr); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %v", err) + } } - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { t.Fatalf("AddAddress for arp failed: %v", err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, - NIC: 1, + NIC: nicID, }}) return &testContext{ - t: t, - s: s, - linkEP: ep, + s: s, + linkEP: ep, + nudDisp: &d, } } @@ -88,7 +221,7 @@ func (c *testContext) cleanup() { } func TestDirectRequest(t *testing.T) { - c := newTestContext(t) + c := newTestContext(t, false /* useNeighborCache */) defer c.cleanup() const senderMAC = "\x01\x02\x03\x04\x05\x06" @@ -103,21 +236,21 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: v.ToVectorisedView(), - }) + })) } - for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { + for i, address := range []tcpip.Address{stackAddr, remoteAddr} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) pi, _ := c.linkEP.ReadContext(context.Background()) if pi.Proto != arp.ProtocolNumber { t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pi.Pkt.Header.View()) + rep := header.ARP(pi.Pkt.NetworkHeader().View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) + t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) @@ -134,7 +267,7 @@ func TestDirectRequest(t *testing.T) { }) } - inject(stackAddrBad) + inject(unknownAddr) // Sleep tests are gross, but this will only potentially flake // if there's a bug. If there is no bug this will reliably // succeed. @@ -144,3 +277,302 @@ func TestDirectRequest(t *testing.T) { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) } } + +func TestDirectRequestWithNeighborCache(t *testing.T) { + c := newTestContext(t, true /* useNeighborCache */) + defer c.cleanup() + + tests := []struct { + name string + senderAddr tcpip.Address + senderLinkAddr tcpip.LinkAddress + targetAddr tcpip.Address + isValid bool + }{ + { + name: "Loopback", + senderAddr: stackAddr, + senderLinkAddr: stackLinkAddr, + targetAddr: stackAddr, + isValid: true, + }, + { + name: "Remote", + senderAddr: remoteAddr, + senderLinkAddr: remoteLinkAddr, + targetAddr: stackAddr, + isValid: true, + }, + { + name: "RemoteInvalidTarget", + senderAddr: remoteAddr, + senderLinkAddr: remoteLinkAddr, + targetAddr: unknownAddr, + isValid: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Inject an incoming ARP request. + v := make(buffer.View, header.ARPSize) + h := header.ARP(v) + h.SetIPv4OverEthernet() + h.SetOp(header.ARPRequest) + copy(h.HardwareAddressSender(), test.senderLinkAddr) + copy(h.ProtocolAddressSender(), test.senderAddr) + copy(h.ProtocolAddressTarget(), test.targetAddr) + c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ + Data: v.ToVectorisedView(), + }) + + if !test.isValid { + // No packets should be sent after receiving an invalid ARP request. + // There is no need to perform a blocking read here, since packets are + // sent in the same function that handles ARP requests. + if pkt, ok := c.linkEP.Read(); ok { + t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto) + } + return + } + + // Verify an ARP response was sent. + pi, ok := c.linkEP.Read() + if !ok { + t.Fatal("expected ARP response to be sent, got none") + } + + if pi.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) + } + rep := header.ARP(pi.Pkt.NetworkHeader().View()) + if !rep.IsValid() { + t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { + t.Errorf("got HardwareAddressSender() = %s, want = %s", got, want) + } + if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { + t.Errorf("got ProtocolAddressSender() = %s, want = %s", got, want) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { + t.Errorf("got HardwareAddressTarget() = %s, want = %s", got, want) + } + if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { + t.Errorf("got ProtocolAddressTarget() = %s, want = %s", got, want) + } + + // Verify the sender was saved in the neighbor cache. + wantEvent := eventInfo{ + eventType: entryAdded, + nicID: nicID, + entry: stack.NeighborEntry{ + Addr: test.senderAddr, + LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), + State: stack.Stale, + }, + } + if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { + t.Fatal(err) + } + + neighbors, err := c.s.Neighbors(nicID) + if err != nil { + t.Fatalf("c.s.Neighbors(%d): %s", nicID, err) + } + + neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) + for _, n := range neighbors { + if existing, ok := neighborByAddr[n.Addr]; ok { + if diff := cmp.Diff(existing, n); diff != "" { + t.Fatalf("duplicate neighbor entry found (-existing +got):\n%s", diff) + } + t.Fatalf("exact neighbor entry duplicate found for addr=%s", n.Addr) + } + neighborByAddr[n.Addr] = n + } + + neigh, ok := neighborByAddr[test.senderAddr] + if !ok { + t.Fatalf("expected neighbor entry with Addr = %s", test.senderAddr) + } + if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want { + t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want) + } + if got, want := neigh.State, stack.Stale; got != want { + t.Errorf("got neighbor State = %s, want = %s", got, want) + } + + // No more events should be dispatched + for { + event, ok := c.nudDisp.nextEvent() + if !ok { + break + } + t.Errorf("unexpected %s", event) + } + }) + } +} + +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.LinkEndpoint + + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + r := stack.Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) +} + +func TestLinkAddressRequest(t *testing.T) { + const nicID = 1 + + testAddr := tcpip.Address([]byte{1, 2, 3, 4}) + + tests := []struct { + name string + nicAddr tcpip.Address + localAddr tcpip.Address + remoteLinkAddr tcpip.LinkAddress + + expectedErr *tcpip.Error + expectedLocalAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress + }{ + { + name: "Unicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "Unicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "Unicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrBadLocalAddress, + }, + { + name: "Multicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: "", + expectedErr: tcpip.ErrBadLocalAddress, + }, + { + name: "Unicast with no local address available", + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with no local address available", + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + }) + p := s.NetworkProtocolInstance(arp.ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } + + linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) + } + } + + // We pass a test network interface to LinkAddressRequest with the same + // NIC ID and link endpoint used by the NIC we created earlier so that we + // can mock a link address request and observe the packets sent to the + // link endpoint even though the stack uses the real NIC to validate the + // local address. + if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } + + if test.expectedErr != nil { + return + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) + } + + rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want { + t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr) + } + }) + } +} diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index d1c728ccf..47fb63290 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -29,6 +29,8 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", ], ) @@ -41,5 +43,10 @@ go_test( "reassembler_test.go", ], library = ":fragmentation", - deps = ["//pkg/tcpip/buffer"], + deps = [ + "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", + "//pkg/tcpip/network/testutil", + "@com_github_google_go_cmp//cmp:go_default_library", + ], ) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index f42abc4bb..936601287 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -13,32 +13,60 @@ // limitations under the License. // Package fragmentation contains the implementation of IP fragmentation. -// It is based on RFC 791 and RFC 815. +// It is based on RFC 791, RFC 815 and RFC 8200. package fragmentation import ( + "errors" "fmt" "log" "time" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. -const DefaultReassembleTimeout = 30 * time.Second +const ( + // HighFragThreshold is the threshold at which we start trimming old + // fragmented packets. Linux uses a default value of 4 MB. See + // net.ipv4.ipfrag_high_thresh for more information. + HighFragThreshold = 4 << 20 // 4MB -// HighFragThreshold is the threshold at which we start trimming old -// fragmented packets. Linux uses a default value of 4 MB. See -// net.ipv4.ipfrag_high_thresh for more information. -const HighFragThreshold = 4 << 20 // 4MB + // LowFragThreshold is the threshold we reach to when we start dropping + // older fragmented packets. It's important that we keep enough room for newer + // packets to be re-assembled. Hence, this needs to be lower than + // HighFragThreshold enough. Linux uses a default value of 3 MB. See + // net.ipv4.ipfrag_low_thresh for more information. + LowFragThreshold = 3 << 20 // 3MB -// LowFragThreshold is the threshold we reach to when we start dropping -// older fragmented packets. It's important that we keep enough room for newer -// packets to be re-assembled. Hence, this needs to be lower than -// HighFragThreshold enough. Linux uses a default value of 3 MB. See -// net.ipv4.ipfrag_low_thresh for more information. -const LowFragThreshold = 3 << 20 // 3MB + // minBlockSize is the minimum block size for fragments. + minBlockSize = 1 +) + +var ( + // ErrInvalidArgs indicates to the caller that that an invalid argument was + // provided. + ErrInvalidArgs = errors.New("invalid args") +) + +// FragmentID is the identifier for a fragment. +type FragmentID struct { + // Source is the source address of the fragment. + Source tcpip.Address + + // Destination is the destination address of the fragment. + Destination tcpip.Address + + // ID is the identification value of the fragment. + // + // This is a uint32 because IPv6 uses a 32-bit identification value. + ID uint32 + + // The protocol for the packet. + Protocol uint8 +} // Fragmentation is the main structure that other modules // of the stack should use to implement IP Fragmentation. @@ -46,14 +74,19 @@ type Fragmentation struct { mu sync.Mutex highLimit int lowLimit int - reassemblers map[uint32]*reassembler + reassemblers map[FragmentID]*reassembler rList reassemblerList size int timeout time.Duration + blockSize uint16 + clock tcpip.Clock + releaseJob *tcpip.Job } // NewFragmentation creates a new Fragmentation. // +// blockSize specifies the fragment block size, in bytes. +// // highMemoryLimit specifies the limit on the memory consumed // by the fragments stored by Fragmentation (overhead of internal data-structures // is not accounted). Fragments are dropped when the limit is reached. @@ -64,7 +97,7 @@ type Fragmentation struct { // reassemblingTimeout specifies the maximum time allowed to reassemble a packet. // Fragments are lazily evicted only when a new a packet with an // already existing fragmentation-id arrives after the timeout. -func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation { +func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation { if lowMemoryLimit >= highMemoryLimit { lowMemoryLimit = highMemoryLimit } @@ -73,44 +106,100 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t lowMemoryLimit = 0 } - return &Fragmentation{ - reassemblers: make(map[uint32]*reassembler), + if blockSize < minBlockSize { + blockSize = minBlockSize + } + + f := &Fragmentation{ + reassemblers: make(map[FragmentID]*reassembler), highLimit: highMemoryLimit, lowLimit: lowMemoryLimit, timeout: reassemblingTimeout, + blockSize: blockSize, + clock: clock, } + f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked) + + return f } -// Process processes an incoming fragment belonging to an ID -// and returns a complete packet when all the packets belonging to that ID have been received. -func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { +// Process processes an incoming fragment belonging to an ID and returns a +// complete packet and its protocol number when all the packets belonging to +// that ID have been received. +// +// [first, last] is the range of the fragment bytes. +// +// first must be a multiple of the block size f is configured with. The size +// of the fragment data must be a multiple of the block size, unless there are +// no fragments following this fragment (more set to false). +// +// proto is the protocol number marked in the fragment being processed. It has +// to be given here outside of the FragmentID struct because IPv6 should not use +// the protocol to identify a fragment. +// +// releaseCB is a callback that will run when the fragment reassembly of a +// packet is complete or cancelled. releaseCB take a a boolean argument which is +// true iff the reassembly is cancelled due to timeout. releaseCB should be +// passed only with the first fragment of a packet. If more than one releaseCB +// are passed for the same packet, only the first releaseCB will be saved for +// the packet and the succeeding ones will be dropped by running them +// immediately with a false argument. +func (f *Fragmentation) Process( + id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) ( + buffer.VectorisedView, uint8, bool, error) { + if first > last { + return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) + } + + if first%f.blockSize != 0 { + return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) + } + + fragmentSize := last - first + 1 + if more && fragmentSize%f.blockSize != 0 { + return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) + } + + if l := vv.Size(); l < int(fragmentSize) { + return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + } + vv.CapLength(int(fragmentSize)) + f.mu.Lock() r, ok := f.reassemblers[id] - if ok && r.tooOld(f.timeout) { - // This is very likely to be an id-collision or someone performing a slow-rate attack. - f.release(r) - ok = false - } if !ok { - r = newReassembler(id) + r = newReassembler(id, f.clock) f.reassemblers[id] = r + wasEmpty := f.rList.Empty() f.rList.PushFront(r) + if wasEmpty { + // If we have just pushed a first reassembler into an empty list, we + // should kickstart the release job. The release job will keep + // rescheduling itself until the list becomes empty. + f.releaseReassemblersLocked() + } + } + if releaseCB != nil { + if !r.setCallback(releaseCB) { + // We got a duplicate callback. Release it immediately. + releaseCB(false /* timedOut */) + } } f.mu.Unlock() - res, done, consumed, err := r.process(first, last, more, vv) + res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv) if err != nil { // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. f.mu.Lock() - f.release(r) + f.release(r, false /* timedOut */) f.mu.Unlock() - return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err) + return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err) } f.mu.Lock() f.size += consumed if done { - f.release(r) + f.release(r, false /* timedOut */) } // Evict reassemblers if we are consuming more memory than highLimit until // we reach lowLimit. @@ -120,14 +209,14 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf if tail == nil { break } - f.release(tail) + f.release(tail, false /* timedOut */) } } f.mu.Unlock() - return res, done, nil + return res, firstFragmentProto, done, nil } -func (f *Fragmentation) release(r *reassembler) { +func (f *Fragmentation) release(r *reassembler, timedOut bool) { // Before releasing a fragment we need to check if r is already marked as done. // Otherwise, we would delete it twice. if r.checkDoneOrMark() { @@ -141,4 +230,105 @@ func (f *Fragmentation) release(r *reassembler) { log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size) f.size = 0 } + + r.release(timedOut) // releaseCB may run. +} + +// releaseReassemblersLocked releases already-expired reassemblers, then +// schedules the job to call back itself for the remaining reassemblers if +// any. This function must be called with f.mu locked. +func (f *Fragmentation) releaseReassemblersLocked() { + now := f.clock.NowMonotonic() + for { + // The reassembler at the end of the list is the oldest. + r := f.rList.Back() + if r == nil { + // The list is empty. + break + } + elapsed := time.Duration(now-r.creationTime) * time.Nanosecond + if f.timeout > elapsed { + // If the oldest reassembler has not expired, schedule the release + // job so that this function is called back when it has expired. + f.releaseJob.Schedule(f.timeout - elapsed) + break + } + // If the oldest reassembler has already expired, release it. + f.release(r, true /* timedOut*/) + } +} + +// PacketFragmenter is the book-keeping struct for packet fragmentation. +type PacketFragmenter struct { + transportHeader buffer.View + data buffer.VectorisedView + reserve int + fragmentPayloadLen int + fragmentCount int + currentFragment int + fragmentOffset int +} + +// MakePacketFragmenter prepares the struct needed for packet fragmentation. +// +// pkt is the packet to be fragmented. +// +// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can +// have. +// +// reserve is the number of bytes that should be reserved for the headers in +// each generated fragment. +func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter { + // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be + // repeated in each fragment. However we do not currently support any header + // of that kind yet, so the following computation is valid for both IPv4 and + // IPv6. + // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are + // supported for outbound packets, the fragmentable data should not include + // these headers. + var fragmentableData buffer.VectorisedView + fragmentableData.AppendView(pkt.TransportHeader().View()) + fragmentableData.Append(pkt.Data) + fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen + + return PacketFragmenter{ + data: fragmentableData, + reserve: reserve, + fragmentPayloadLen: int(fragmentPayloadLen), + fragmentCount: int(fragmentCount), + } +} + +// BuildNextFragment returns a packet with the payload of the next fragment, +// along with the fragment's offset, the number of bytes copied and a boolean +// indicating if there are more fragments left or not. If this function is +// called again after it indicated that no more fragments were left, it will +// panic. +// +// Note that the returned packet will not have its network and link headers +// populated, but space for them will be reserved. The transport header will be +// stored in the packet's data. +func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) { + if pf.currentFragment >= pf.fragmentCount { + panic("BuildNextFragment should not be called again after the last fragment was returned") + } + + fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: pf.reserve, + }) + + // Copy data for the fragment. + copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen) + + offset := pf.fragmentOffset + pf.fragmentOffset += copied + pf.currentFragment++ + more := pf.currentFragment != pf.fragmentCount + + return fragPkt, offset, copied, more +} + +// RemainingFragmentCount returns the number of fragments left to be built. +func (pf *PacketFragmenter) RemainingFragmentCount() int { + return pf.fragmentCount - pf.currentFragment } diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 72c0f53be..5dcd10730 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -15,13 +15,21 @@ package fragmentation import ( + "errors" "reflect" "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" ) +// reassembleTimeout is dummy timeout used for testing, where the clock never +// advances. +const reassembleTimeout = 1 + // vv is a helper to build VectorisedView from different strings. func vv(size int, pieces ...string) buffer.VectorisedView { views := make([]buffer.View, len(pieces)) @@ -33,16 +41,18 @@ func vv(size int, pieces ...string) buffer.VectorisedView { } type processInput struct { - id uint32 + id FragmentID first uint16 last uint16 more bool + proto uint8 vv buffer.VectorisedView } type processOutput struct { - vv buffer.VectorisedView - done bool + vv buffer.VectorisedView + proto uint8 + done bool } var processTestCases = []struct { @@ -53,8 +63,8 @@ var processTestCases = []struct { { comment: "One ID", in: []processInput{ - {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -62,12 +72,23 @@ var processTestCases = []struct { }, }, { + comment: "Next Header protocol mismatch", + in: []processInput{ + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")}, + }, + out: []processOutput{ + {vv: buffer.VectorisedView{}, done: false}, + {vv: vv(4, "01", "23"), proto: 6, done: true}, + }, + }, + { comment: "Two IDs", in: []processInput{ - {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")}, - {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")}, - {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, + {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")}, + {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -81,19 +102,27 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(1024, 512, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}) + firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv) + vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil) if err != nil { - t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err) + t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s", + in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err) } if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) + t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)", + in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView()) } if done != c.out[i].done { - t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done) + t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", + in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) } if c.out[i].done { + if firstFragmentProto != proto { + t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", + in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) + } if _, ok := f.reassemblers[in.id]; ok { t.Errorf("Process(%d) did not remove buffer from reassemblers", i) } @@ -109,53 +138,154 @@ func TestFragmentationProcess(t *testing.T) { } func TestReassemblingTimeout(t *testing.T) { - timeout := time.Millisecond - f := NewFragmentation(1024, 512, timeout) - // Send first fragment with id = 0, first = 0, last = 0, and more = true. - f.Process(0, 0, 0, true, vv(1, "0")) - // Sleep more than the timeout. - time.Sleep(2 * timeout) - // Send another fragment that completes a packet. - // However, no packet should be reassembled because the fragment arrived after the timeout. - _, done, err := f.Process(0, 1, 1, false, vv(1, "1")) - if err != nil { - t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err) - } - if done { - t.Errorf("Fragmentation does not respect the reassembling timeout.") + const ( + reassemblyTimeout = time.Millisecond + protocol = 0xff + ) + + type fragment struct { + first uint16 + last uint16 + more bool + data string + } + + type event struct { + // name is a nickname of this event. + name string + + // clockAdvance is a duration to advance the clock. The clock advances + // before a fragment specified in the fragment field is processed. + clockAdvance time.Duration + + // fragment is a fragment to process. This can be nil if there is no + // fragment to process. + fragment *fragment + + // expectDone is true if the fragmentation instance should report the + // reassembly is done after the fragment is processd. + expectDone bool + + // sizeAfterEvent is the expected size of the fragmentation instance after + // the event. + sizeAfterEvent int + } + + half1 := &fragment{first: 0, last: 0, more: true, data: "0"} + half2 := &fragment{first: 1, last: 1, more: false, data: "1"} + + tests := []struct { + name string + events []event + }{ + { + name: "half1 and half2 are reassembled successfully", + events: []event{ + { + name: "half1", + fragment: half1, + expectDone: false, + sizeAfterEvent: 1, + }, + { + name: "half2", + fragment: half2, + expectDone: true, + sizeAfterEvent: 0, + }, + }, + }, + { + name: "half1 timeout, half2 timeout", + events: []event{ + { + name: "half1", + fragment: half1, + expectDone: false, + sizeAfterEvent: 1, + }, + { + name: "half1 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + sizeAfterEvent: 1, + }, + { + name: "half1 reassembly timeout", + clockAdvance: 1, + sizeAfterEvent: 0, + }, + { + name: "half2", + fragment: half2, + expectDone: false, + sizeAfterEvent: 1, + }, + { + name: "half2 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + sizeAfterEvent: 1, + }, + { + name: "half2 reassembly timeout", + clockAdvance: 1, + sizeAfterEvent: 0, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock) + for _, event := range test.events { + clock.Advance(event.clockAdvance) + if frag := event.fragment; frag != nil { + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil) + if err != nil { + t.Fatalf("%s: f.Process failed: %s", event.name, err) + } + if done != event.expectDone { + t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) + } + } + if got, want := f.size, event.sizeAfterEvent; got != want { + t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want) + } + } + }) } } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(3, 1, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. - f.Process(0, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil) // Send first fragment with id = 1. - f.Process(1, 0, 0, true, vv(1, "1")) + f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil) // Send first fragment with id = 2. - f.Process(2, 0, 0, true, vv(1, "2")) + f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil) // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(3, 0, 0, true, vv(1, "3")) + f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil) - if _, ok := f.reassemblers[0]; ok { + if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") } - if _, ok := f.reassemblers[1]; ok { + if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok { t.Errorf("Memory limits are not respected: id=1 has not been evicted.") } - if _, ok := f.reassemblers[3]; !ok { + if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok { t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") } } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(1, 0, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. - f.Process(0, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) // Send the same packet again. - f.Process(0, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) got := f.size want := 1 @@ -163,3 +293,293 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) } } + +func TestErrors(t *testing.T) { + tests := []struct { + name string + blockSize uint16 + first uint16 + last uint16 + more bool + data string + err error + }{ + { + name: "exact block size without more", + blockSize: 2, + first: 2, + last: 3, + more: false, + data: "01", + }, + { + name: "exact block size with more", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "01", + }, + { + name: "exact block size with more and extra data", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "012", + }, + { + name: "exact block size with more and too little data", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "0", + err: ErrInvalidArgs, + }, + { + name: "not exact block size with more", + blockSize: 2, + first: 2, + last: 2, + more: true, + data: "0", + err: ErrInvalidArgs, + }, + { + name: "not exact block size without more", + blockSize: 2, + first: 2, + last: 2, + more: false, + data: "0", + }, + { + name: "first not a multiple of block size", + blockSize: 2, + first: 3, + last: 4, + more: true, + data: "01", + err: ErrInvalidArgs, + }, + { + name: "first more than last", + blockSize: 2, + first: 4, + last: 3, + more: true, + data: "01", + err: ErrInvalidArgs, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) + _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil) + if !errors.Is(err, test.err) { + t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) + } + if done { + t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data) + } + }) + } +} + +type fragmentInfo struct { + remaining int + copied int + offset int + more bool +} + +func TestPacketFragmenter(t *testing.T) { + const ( + reserve = 60 + proto = 0 + ) + + tests := []struct { + name string + fragmentPayloadLen uint32 + transportHeaderLen int + payloadSize int + wantFragments []fragmentInfo + }{ + { + name: "Packet exactly fits in MTU", + fragmentPayloadLen: 1280, + transportHeaderLen: 0, + payloadSize: 1280, + wantFragments: []fragmentInfo{ + {remaining: 0, copied: 1280, offset: 0, more: false}, + }, + }, + { + name: "Packet exactly does not fit in MTU", + fragmentPayloadLen: 1000, + transportHeaderLen: 0, + payloadSize: 1001, + wantFragments: []fragmentInfo{ + {remaining: 1, copied: 1000, offset: 0, more: true}, + {remaining: 0, copied: 1, offset: 1000, more: false}, + }, + }, + { + name: "Packet has a transport header", + fragmentPayloadLen: 560, + transportHeaderLen: 40, + payloadSize: 560, + wantFragments: []fragmentInfo{ + {remaining: 1, copied: 560, offset: 0, more: true}, + {remaining: 0, copied: 40, offset: 560, more: false}, + }, + }, + { + name: "Packet has a huge transport header", + fragmentPayloadLen: 500, + transportHeaderLen: 1300, + payloadSize: 500, + wantFragments: []fragmentInfo{ + {remaining: 3, copied: 500, offset: 0, more: true}, + {remaining: 2, copied: 500, offset: 500, more: true}, + {remaining: 1, copied: 500, offset: 1000, more: true}, + {remaining: 0, copied: 300, offset: 1500, more: false}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) + var originalPayload buffer.VectorisedView + originalPayload.AppendView(pkt.TransportHeader().View()) + originalPayload.Append(pkt.Data) + var reassembledPayload buffer.VectorisedView + pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) + for i := 0; ; i++ { + fragPkt, offset, copied, more := pf.BuildNextFragment() + wantFragment := test.wantFragments[i] + if got := pf.RemainingFragmentCount(); got != wantFragment.remaining { + t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining) + } + if copied != wantFragment.copied { + t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied) + } + if offset != wantFragment.offset { + t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset) + } + if more != wantFragment.more { + t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) + } + if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { + t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) + } + if got := fragPkt.AvailableHeaderBytes(); got != reserve { + t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) + } + if got := fragPkt.TransportHeader().View().Size(); got != 0 { + t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) + } + reassembledPayload.Append(fragPkt.Data) + if !more { + if i != len(test.wantFragments)-1 { + t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) + } + break + } + } + if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" { + t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestReleaseCallback(t *testing.T) { + const ( + proto = 99 + ) + + var result int + var callbackReasonIsTimeout bool + cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut } + cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut } + + tests := []struct { + name string + callbacks []func(bool) + timeout bool + wantResult int + wantCallbackReasonIsTimeout bool + }{ + { + name: "callback runs on release", + callbacks: []func(bool){cb1}, + timeout: false, + wantResult: 1, + wantCallbackReasonIsTimeout: false, + }, + { + name: "first callback is nil", + callbacks: []func(bool){nil, cb2}, + timeout: false, + wantResult: 2, + wantCallbackReasonIsTimeout: false, + }, + { + name: "two callbacks - first one is set", + callbacks: []func(bool){cb1, cb2}, + timeout: false, + wantResult: 1, + wantCallbackReasonIsTimeout: false, + }, + { + name: "callback runs on timeout", + callbacks: []func(bool){cb1}, + timeout: true, + wantResult: 1, + wantCallbackReasonIsTimeout: true, + }, + { + name: "no callbacks", + callbacks: []func(bool){nil}, + timeout: false, + wantResult: 0, + wantCallbackReasonIsTimeout: false, + }, + } + + id := FragmentID{ID: 0} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result = 0 + callbackReasonIsTimeout = false + + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) + + for i, cb := range test.callbacks { + _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb) + if err != nil { + t.Errorf("f.Process error = %s", err) + } + } + + r, ok := f.reassemblers[id] + if !ok { + t.Fatalf("Reassemberr not found") + } + f.release(r, test.timeout) + + if result != test.wantResult { + t.Errorf("got result = %d, want = %d", result, test.wantResult) + } + if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout { + t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout) + } + }) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 0a83d81f2..c0cc0bde0 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -18,9 +18,9 @@ import ( "container/heap" "fmt" "math" - "time" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -32,23 +32,24 @@ type hole struct { type reassembler struct { reassemblerEntry - id uint32 + id FragmentID size int + proto uint8 mu sync.Mutex holes []hole deleted int heap fragHeap done bool - creationTime time.Time + creationTime int64 + callback func(bool) } -func newReassembler(id uint32) *reassembler { +func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ id: id, holes: make([]hole, 0, 16), - deleted: 0, heap: make(fragHeap, 0, 8), - creationTime: time.Now(), + creationTime: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ first: 0, @@ -78,7 +79,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool { return used } -func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) { +func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() consumed := 0 @@ -86,7 +87,18 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, false, consumed, nil + return buffer.VectorisedView{}, 0, false, consumed, nil + } + // For IPv6, it is possible to have different Protocol values between + // fragments of a packet (because, unlike IPv4, the Protocol is not used to + // identify a fragment). In this case, only the Protocol of the first + // fragment must be used as per RFC 8200 Section 4.5. + // + // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded + // here (instead of just the protocol) because most IP options should be + // derived from the first fragment. + if first == 0 { + r.proto = proto } if r.updateHoles(first, last, more) { // We store the incoming packet only if it filled some holes. @@ -96,17 +108,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise } // Check if all the holes have been deleted and we are ready to reassamble. if r.deleted < len(r.holes) { - return buffer.VectorisedView{}, false, consumed, nil + return buffer.VectorisedView{}, 0, false, consumed, nil } res, err := r.heap.reassemble() if err != nil { - return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err) + return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err) } - return res, true, consumed, nil -} - -func (r *reassembler) tooOld(timeout time.Duration) bool { - return time.Now().Sub(r.creationTime) > timeout + return res, r.proto, true, consumed, nil } func (r *reassembler) checkDoneOrMark() bool { @@ -116,3 +124,24 @@ func (r *reassembler) checkDoneOrMark() bool { r.mu.Unlock() return prev } + +func (r *reassembler) setCallback(c func(bool)) bool { + r.mu.Lock() + defer r.mu.Unlock() + if r.callback != nil { + return false + } + r.callback = c + return true +} + +func (r *reassembler) release(timedOut bool) { + r.mu.Lock() + callback := r.callback + r.callback = nil + r.mu.Unlock() + + if callback != nil { + callback(timedOut) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index 7eee0710d..fa2a70dc8 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -18,6 +18,8 @@ import ( "math" "reflect" "testing" + + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) type updateHolesInput struct { @@ -94,7 +96,7 @@ var holesTestCases = []struct { func TestUpdateHoles(t *testing.T) { for _, c := range holesTestCases { - r := newReassembler(0) + r := newReassembler(FragmentID{}, &faketime.NullClock{}) for _, i := range c.in { r.updateHoles(i.first, i.last, i.more) } @@ -103,3 +105,26 @@ func TestUpdateHoles(t *testing.T) { } } } + +func TestSetCallback(t *testing.T) { + result := 0 + reasonTimeout := false + + cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut } + cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut } + + r := newReassembler(FragmentID{}, &faketime.NullClock{}) + if !r.setCallback(cb1) { + t.Errorf("setCallback failed") + } + if r.setCallback(cb2) { + t.Errorf("setCallback should fail if one is already set") + } + r.release(true) + if result != 1 { + t.Errorf("got result = %d, want = 1", result) + } + if !reasonTimeout { + t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4c20301c6..969579601 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,34 +15,48 @@ package ip_test import ( + "strings" "testing" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) const ( - localIpv4Addr = "\x0a\x00\x00\x01" - localIpv4PrefixLen = 24 - remoteIpv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - localIpv6PrefixLen = 120 - remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + localIPv4Addr = "\x0a\x00\x00\x01" + remoteIPv4Addr = "\x0a\x00\x00\x02" + ipv4SubnetAddr = "\x0a\x00\x00\x00" + ipv4SubnetMask = "\xff\xff\xff\x00" + ipv4Gateway = "\x0a\x00\x00\x03" + localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" + ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + nicID = 1 ) +var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: localIPv4Addr, + PrefixLen: 24, +} + +var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: localIPv6Addr, + PrefixLen: 120, +} + // testObject implements two interfaces: LinkEndpoint and TransportDispatcher. // The former is used to pretend that it's a link endpoint so that we can // inspect packets written by the network endpoints. The latter is used to @@ -96,15 +110,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) { +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ + return stack.TransportPacketHandled } // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) @@ -150,19 +165,19 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(pkt.Header.View()) + h := header.IPv4(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(pkt.Header.View()) + h := header.IPv6(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() @@ -172,60 +187,345 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } -func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { +func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testObject) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("not implemented") +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv4.ProtocolNumber, local) + s.CreateNIC(nicID, loopback.New()) + s.AddAddress(nicID, ipv4.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, Gateway: ipv4Gateway, NIC: 1, }}) - return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) + return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv6.ProtocolNumber, local) + s.CreateNIC(nicID, loopback.New()) + s.AddAddress(nicID, ipv6.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, Gateway: ipv6Gateway, NIC: 1, }}) - return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) + return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) } -func buildDummyStack() *stack.Stack { - return stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, +func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) { + t.Helper() + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, }) + e := channel.New(0, 1280, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} + if err := s.AddProtocolAddress(nicID, v4Addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err) + } + + v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} + if err := s.AddProtocolAddress(nicID, v6Addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err) + } + + return s, e +} + +func buildDummyStack(t *testing.T) *stack.Stack { + t.Helper() + + s, _ := buildDummyStackWithLinkEndpoint(t) + return s +} + +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + testObject + + mu struct { + sync.RWMutex + disabled bool + } +} + +func (*testInterface) ID() tcpip.NICID { + return nicID +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (t *testInterface) Enabled() bool { + t.mu.RLock() + defer t.mu.RUnlock() + return !t.mu.disabled +} + +func (t *testInterface) setEnabled(v bool) { + t.mu.Lock() + defer t.mu.Unlock() + t.mu.disabled = !v +} + +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + +func TestSourceAddressValidation(t *testing.T) { + rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv4Addr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv6Addr, + }) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + tests := []struct { + name string + srcAddress tcpip.Address + rxICMP func(*channel.Endpoint, tcpip.Address) + valid bool + }{ + { + name: "IPv4 valid", + srcAddress: "\x01\x02\x03\x04", + rxICMP: rxIPv4ICMP, + valid: true, + }, + { + name: "IPv6 valid", + srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10", + rxICMP: rxIPv6ICMP, + valid: true, + }, + { + name: "IPv4 unspecified", + srcAddress: header.IPv4Any, + rxICMP: rxIPv4ICMP, + valid: true, + }, + { + name: "IPv6 unspecified", + srcAddress: header.IPv4Any, + rxICMP: rxIPv6ICMP, + valid: true, + }, + { + name: "IPv4 multicast", + srcAddress: "\xe0\x00\x00\x01", + rxICMP: rxIPv4ICMP, + valid: false, + }, + { + name: "IPv6 multicast", + srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + rxICMP: rxIPv6ICMP, + valid: false, + }, + { + name: "IPv4 broadcast", + srcAddress: header.IPv4Broadcast, + rxICMP: rxIPv4ICMP, + valid: false, + }, + { + name: "IPv4 subnet broadcast", + srcAddress: func() tcpip.Address { + subnet := localIPv4AddrWithPrefix.Subnet() + return subnet.Broadcast() + }(), + rxICMP: rxIPv4ICMP, + valid: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, e := buildDummyStackWithLinkEndpoint(t) + test.rxICMP(e, test.srcAddress) + + var wantValid uint64 + if test.valid { + wantValid = 1 + } + + if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want { + t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) + } + if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid { + t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid) + } + }) + } +} + +func TestEnableWhenNICDisabled(t *testing.T) { + tests := []struct { + name string + protocolFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + }{ + { + name: "IPv4", + protocolFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + }, + { + name: "IPv6", + protocolFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var nic testInterface + nic.setEnabled(false) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory}, + }) + p := s.NetworkProtocolInstance(test.protoNum) + + // We pass nil for all parameters except the NetworkInterface and Stack + // since Enable only depends on these. + ep := p.NewEndpoint(&nic, nil, nil, nil) + + // The endpoint should initially be disabled, regardless the NIC's enabled + // status. + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + nic.setEnabled(true) + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + + // Attempting to enable the endpoint while the NIC is disabled should + // fail. + nic.setEnabled(false) + if err := ep.Enable(); err != tcpip.ErrNotPermitted { + t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted) + } + // ep should consider the NIC's enabled status when determining its own + // enabled status so we "enable" the NIC to read just the endpoint's + // enabled status. + nic.setEnabled(true) + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + + // Enabling the interface after the NIC has been enabled should succeed. + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + if !ep.Enabled() { + t.Fatal("got ep.Enabled() = false, want = true") + } + + // ep should consider the NIC's enabled status when determining its own + // enabled status. + nic.setEnabled(false) + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + + // Disabling the endpoint when the NIC is enabled should make the endpoint + // disabled. + nic.setEnabled(true) + ep.Disable() + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + }) + } } func TestIPv4Send(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + nic := testInterface{ + testObject: testObject{ + t: t, + v4: true, + }, } + ep := proto.NewEndpoint(&nic, nil, nil, nil) + defer ep.Close() // Allocate and initialize the payload view. payload := buffer.NewView(100) @@ -233,33 +533,45 @@ func TestIPv4Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv4Addr - o.dstAddr = remoteIpv4Addr - o.contents = payload + nic.testObject.protocol = 123 + nic.testObject.srcAddr = localIPv4Addr + nic.testObject.dstAddr = remoteIPv4Addr + nic.testObject.contents = payload - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } func TestIPv4Receive(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + nic := testInterface{ + testObject: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) } totalLen := header.IPv4MinimumSize + 30 @@ -270,9 +582,10 @@ func TestIPv4Receive(t *testing.T) { TotalLength: uint16(totalLen), TTL: 20, Protocol: 10, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { @@ -280,20 +593,24 @@ func TestIPv4Receive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[header.IPv4MinimumSize:totalLen] + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: view.ToVectorisedView(), }) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) + if nic.testObject.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } } @@ -303,7 +620,7 @@ func TestIPv4ReceiveControl(t *testing.T) { name string expectedCount int fragmentOffset uint16 - code uint8 + code header.ICMPv4Code expectedTyp stack.ControlType expectedExtra uint32 trunc int @@ -317,20 +634,26 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, } - r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb") + r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb") if err != nil { t.Fatal(err) } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + nic := testInterface{ + testObject: testObject{ + t: t, + }, } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize view := buffer.NewView(dataOffset + 8) @@ -342,8 +665,9 @@ func TestIPv4ReceiveControl(t *testing.T) { TTL: 20, Protocol: uint8(header.ICMPv4ProtocolNumber), SrcAddr: "\x0a\x00\x00\xbb", - DstAddr: localIpv4Addr, + DstAddr: localIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) // Create the ICMP header. icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) @@ -360,41 +684,51 @@ func TestIPv4ReceiveControl(t *testing.T) { TTL: 20, Protocol: 10, FragmentOffset: c.fragmentOffset, - SrcAddr: localIpv4Addr, - DstAddr: remoteIpv4Addr, + SrcAddr: localIPv4Addr, + DstAddr: remoteIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) // Make payload be non-zero. for i := dataOffset; i < len(view); i++ { view[i] = uint8(i) } + icmp.SetChecksum(0) + checksum := ^header.Checksum(icmp, 0 /* initial */) + icmp.SetChecksum(checksum) + // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra - - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: vv, - }) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[dataOffset:] + nic.testObject.typ = c.expectedTyp + nic.testObject.extra = c.expectedExtra + + ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) + if want := c.expectedCount; nic.testObject.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } }) } } func TestIPv4FragmentationReceive(t *testing.T) { - o := testObject{t: t, v4: true} - proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) + nic := testInterface{ + testObject: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) } totalLen := header.IPv4MinimumSize + 24 @@ -408,9 +742,11 @@ func TestIPv4FragmentationReceive(t *testing.T) { Protocol: 10, FragmentOffset: 0, Flags: header.IPv4FlagMoreFragments, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, }) + ip1.SetChecksum(^ip1.CalculateChecksum()) + // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { frag1[i] = uint8(i) @@ -424,48 +760,65 @@ func TestIPv4FragmentationReceive(t *testing.T) { TTL: 20, Protocol: 10, FragmentOffset: 24, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, }) + ip2.SetChecksum(^ip2.CalculateChecksum()) + // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { frag2[i] = uint8(i) } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { t.Fatalf("could not find route: %v", err) } // Send first segment. - ep.HandlePacket(&r, stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: frag1.ToVectorisedView(), }) - if o.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) + if nic.testObject.dataCalls != 0 { + t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) } // Send second segment. - ep.HandlePacket(&r, stack.PacketBuffer{ + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: frag2.ToVectorisedView(), }) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) + if nic.testObject.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } } func TestIPv6Send(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + nic := testInterface{ + testObject: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, nil) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) } // Allocate and initialize the payload view. @@ -474,33 +827,44 @@ func TestIPv6Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv6Addr - o.dstAddr = remoteIpv6Addr - o.contents = payload + nic.testObject.protocol = 123 + nic.testObject.srcAddr = localIPv6Addr + nic.testObject.dstAddr = remoteIPv6Addr + nic.testObject.contents = payload - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } func TestIPv6Receive(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + nic := testInterface{ + testObject: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) } totalLen := header.IPv6MinimumSize + 30 @@ -510,8 +874,8 @@ func TestIPv6Receive(t *testing.T) { PayloadLength: uint16(totalLen - header.IPv6MinimumSize), NextHeader: 10, HopLimit: 20, - SrcAddr: remoteIpv6Addr, - DstAddr: localIpv6Addr, + SrcAddr: remoteIPv6Addr, + DstAddr: localIPv6Addr, }) // Make payload be non-zero. @@ -520,21 +884,25 @@ func TestIPv6Receive(t *testing.T) { } // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[header.IPv6MinimumSize:totalLen] + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[header.IPv6MinimumSize:totalLen] - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: view.ToVectorisedView(), }) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) + if nic.testObject.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } } @@ -548,7 +916,7 @@ func TestIPv6ReceiveControl(t *testing.T) { expectedCount int fragmentOffset *uint16 typ header.ICMPv6Type - code uint8 + code header.ICMPv6Code expectedTyp stack.ControlType expectedExtra uint32 trunc int @@ -565,7 +933,7 @@ func TestIPv6ReceiveControl(t *testing.T) { {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, } r, err := buildIPv6Route( - localIpv6Addr, + localIPv6Addr, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", ) if err != nil { @@ -573,15 +941,20 @@ func TestIPv6ReceiveControl(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} - proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + s := buildDummyStack(t) + proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) + nic := testInterface{ + testObject: testObject{ + t: t, + }, } - + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize if c.fragmentOffset != nil { dataOffset += header.IPv6FragmentHeaderSize @@ -595,7 +968,7 @@ func TestIPv6ReceiveControl(t *testing.T) { NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: 20, SrcAddr: outerSrcAddr, - DstAddr: localIpv6Addr, + DstAddr: localIPv6Addr, }) // Create the ICMP header. @@ -611,8 +984,8 @@ func TestIPv6ReceiveControl(t *testing.T) { PayloadLength: 100, NextHeader: 10, HopLimit: 20, - SrcAddr: localIpv6Addr, - DstAddr: remoteIpv6Addr, + SrcAddr: localIPv6Addr, + DstAddr: remoteIPv6Addr, }) // Build the fragmentation header if needed. @@ -634,21 +1007,435 @@ func TestIPv6ReceiveControl(t *testing.T) { // Give packet to IPv6 endpoint, dispatcher will validate that // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[dataOffset:] + nic.testObject.typ = c.expectedTyp + nic.testObject.extra = c.expectedExtra // Set ICMPv6 checksum. - icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: view[:len(view)-c.trunc].ToVectorisedView(), - }) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) + if want := c.expectedCount; nic.testObject.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) + } + }) + } +} + +// truncatedPacket returns a PacketBuffer based on a truncated view. If view, +// after truncation, is large enough to hold a network header, it makes part of +// view the packet's NetworkHeader and the rest its Data. Otherwise all of view +// becomes Data. +func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { + v := view[:len(view)-trunc] + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + _, _ = pkt.NetworkHeader().Consume(netHdrLen) + return pkt +} + +func TestWriteHeaderIncludedPacket(t *testing.T) { + const ( + nicID = 1 + transportProto = 5 + + dataLen = 4 + optionsLen = 4 + ) + + dataBuf := [dataLen]byte{1, 2, 3, 4} + data := dataBuf[:] + + ipv4OptionsBuf := [optionsLen]byte{0, 1, 0, 1} + ipv4Options := ipv4OptionsBuf[:] + + ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} + ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] + + var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte + ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:] + if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) + } + if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + nicAddr tcpip.Address + remoteAddr tcpip.Address + pktGen func(*testing.T, tcpip.Address) buffer.View + checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) + expectedErr *tcpip.Error + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv4MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv4MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv4 with IHL too small", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv4MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 1, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + { + name: "IPv4 too small", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip[:len(ip)-1]) + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + { + name: "IPv4 minimum size", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip) + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv4MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(header.IPv4MinimumSize), + checker.IPPayload(nil), + ) + }, + }, + { + name: "IPv4 with options", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ipHdrLen := header.IPv4MinimumSize + len(ipv4Options) + totalLen := ipHdrLen + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(ipHdrLen)) + ip.Encode(&header.IPv4Fields{ + IHL: uint8(ipHdrLen), + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options)) + } + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + hdrLen := header.IPv4MinimumSize + len(ipv4Options) + if len(netHdr.View()) != hdrLen { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(hdrLen), + checker.IPFullLength(uint16(hdrLen+len(data))), + checker.IPv4Options(ipv4Options), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv6MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv6MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv6 with extension header", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))), + checker.IPPayload(ipv6PayloadWithExtHdr), + ) + }, + }, + { + name: "IPv6 minimum size", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip) + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv6MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(header.IPv6MinimumSize), + checker.IPPayload(nil), + ) + }, + }, + { + name: "IPv6 too small", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip[:len(ip)-1]) + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + subTests := []struct { + name string + srcAddr tcpip.Address + }{ + { + name: "unspecified source", + srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), + }, + { + name: "random source", + srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), + }, + } + + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, + }) + e := channel.New(1, 1280, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) + + r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) + } + defer r.Release() + + if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(), + })); err != test.expectedErr { + t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) + } + + if test.expectedErr != nil { + return + } + + pkt, ok := e.Read() + if !ok { + t.Fatal("expected a packet to be written") + } + test.checker(t, pkt.Pkt, subTest.srcAddr) + }) } }) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 78420d6e6..6252614ec 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -10,9 +10,11 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", @@ -26,14 +28,19 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..cf287446e 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,16 +15,20 @@ package ipv4 import ( + "errors" + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// handleControl handles the case when an ICMP packet contains the headers of -// the original packet that caused the ICMP one to be sent. This information is -// used to find out which transport endpoint must be notified about the ICMP -// packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +// handleControl handles the case when an ICMP error packet contains the headers +// of the original packet that caused the ICMP one to be sent. This information +// is used to find out which transport endpoint must be notified about the ICMP +// packet. We only expect the payload, not the enclosing ICMP packet. +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { return @@ -37,8 +41,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // false. // // Drop packet if it doesn't have the basic IPv4 header or if the - // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + // original source address doesn't match an address we own. + src := hdr.SourceAddress() + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { return } @@ -53,12 +58,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { received.Invalid.Increment() @@ -66,47 +74,142 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { } h := header.ICMPv4(v) + // Only do in-stack processing if the checksum is correct. + if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff { + received.Invalid.Increment() + // It's possible that a raw socket expects to receive this regardless + // of checksum errors. If it's an echo request we know it's safe because + // we are the only handler, however other types do not cope well with + // packets with checksum errors. + switch h.Type() { + case header.ICMPv4Echo: + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + } + return + } + + iph := header.IPv4(pkt.NetworkHeader().View()) + var newOptions header.IPv4Options + if len(iph) > header.IPv4MinimumSize { + // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip + // type ICMP packets): + // If a Record Route and/or Time Stamp option is received in an + // ICMP Echo Request, this option (these options) SHOULD be + // updated to include the current host and included in the IP + // header of the Echo Reply message, without "truncation". + // Thus, the recorded route will be for the entire round trip. + // + // So we need to let the option processor know how it should handle them. + var op optionsUsage + if h.Type() == header.ICMPv4Echo { + op = &optionUsageEcho{} + } else { + op = &optionUsageReceive{} + } + aux, tmp, err := processIPOptions(r, iph.Options(), op) + if err != nil { + switch { + case + errors.Is(err, header.ErrIPv4OptDuplicate), + errors.Is(err, errIPv4RecordRouteOptInvalidLength), + errors.Is(err, errIPv4RecordRouteOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptOverflow): + _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt) + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + r.Stats().IP.MalformedPacketsReceived.Increment() + } + return + } + newOptions = tmp + } + // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv4Echo: received.Echo.Increment() - // Only send a reply if the checksum is valid. - wantChecksum := h.Checksum() - // Reset the checksum field to 0 to can calculate the proper - // checksum. We'll have to reset this before we hand the packet - // off. - h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - if gotChecksum != wantChecksum { - // It's possible that a raw socket expects to receive this. - h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) - received.Invalid.Increment() + sent := stats.ICMP.V4PacketsSent + if !r.Stack().AllowICMPMessage() { + sent.RateLimited.Increment() return } + // DeliverTransportPacket will take ownership of pkt so don't use it beyond + // this point. Make a deep copy of the data before pkt gets sent as we will + // be modifying fields. + // + // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no + // waiting endpoints. Consider moving responsibility for doing the copy to + // DeliverTransportPacket so that is is only done when needed. + replyData := pkt.Data.ToOwnedView() + // It's possible that a raw socket expects to receive this. - h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{ - Data: pkt.Data.Clone(nil), - NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + pkt = nil + // Take the base of the incoming request IP header but replace the options. + replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions)) + replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...)) + replyIPHdr.SetHeaderLength(replyHeaderLength) + + // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP + // source address MUST be one of its own IP addresses (but not a broadcast + // or multicast address). + localAddr := r.LocalAddress + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) { + localAddr = "" + } + + r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + // If we cannot find a route to the destination, silently drop the packet. + return + } + defer r.Release() + + // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the + // header information, we may have to change this code to handle the + // ICMP header no longer being in the data buffer. + + // Because IP and ICMP are so closely intertwined, we need to handcraft our + // IP header to be able to follow RFC 792. The wording on page 13 is as + // follows: + // IP Fields: + // Addresses + // The address of the source in an echo message will be the + // destination of the echo reply message. To form an echo reply + // message, the source and destination addresses are simply reversed, + // the type code changed to 0, and the checksum recomputed. + // + // This was interpreted by early implementors to mean that all options must + // be copied from the echo request IP header to the echo reply IP header + // and this behaviour is still relied upon by some applications. + // + // Create a copy of the IP header we received, options and all, and change + // The fields we need to alter. + // + // We need to produce the entire packet in the data segment in order to + // use WriteHeaderIncludedPacket(). WriteHeaderIncludedPacket sets the + // total length and the header checksum so we don't need to set those here. + replyIPHdr.SetSourceAddress(r.LocalAddress) + replyIPHdr.SetDestinationAddress(r.RemoteAddress) + replyIPHdr.SetTTL(r.DefaultTTL()) + + replyICMPHdr := header.ICMPv4(replyData) + replyICMPHdr.SetType(header.ICMPv4EchoReply) + replyICMPHdr.SetChecksum(0) + replyICMPHdr.SetChecksum(^header.Checksum(replyData, 0)) + + replyVV := buffer.View(replyIPHdr).ToVectorisedView() + replyVV.AppendView(replyData) + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: replyVV, }) + replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber - vv := pkt.Data.Clone(nil) - vv.TrimFront(header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv4EchoReply) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) - sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ - Header: hdr, - Data: vv, - TransportHeader: buffer.View(pkt), - }); err != nil { + if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { sent.Dropped.Increment() return } @@ -122,12 +225,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { + case header.ICMPv4HostUnreachable: + e.handleControl(stack.ControlNoRoute, 0, pkt) + case header.ICMPv4PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: - mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) + if err != nil { + networkMTU = 0 + } + e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) } case header.ICMPv4SrcQuench: @@ -158,3 +267,217 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { received.Invalid.Increment() } } + +// ======= ICMP Error packet generation ========= + +// icmpReason is a marker interface for IPv4 specific ICMP errors. +type icmpReason interface { + isICMPReason() +} + +// icmpReasonPortUnreachable is an error where the transport protocol has no +// listener and no alternative means to inform the sender. +type icmpReasonPortUnreachable struct{} + +func (*icmpReasonPortUnreachable) isICMPReason() {} + +// icmpReasonProtoUnreachable is an error where the transport protocol is +// not supported. +type icmpReasonProtoUnreachable struct{} + +func (*icmpReasonProtoUnreachable) isICMPReason() {} + +// icmpReasonReassemblyTimeout is an error where insufficient fragments are +// received to complete reassembly of a packet within a configured time after +// the reception of the first-arriving fragment of that packet. +type icmpReasonReassemblyTimeout struct{} + +func (*icmpReasonReassemblyTimeout) isICMPReason() {} + +// icmpReasonParamProblem is an error to use to request a Parameter Problem +// message to be sent. +type icmpReasonParamProblem struct { + pointer byte +} + +func (*icmpReasonParamProblem) isICMPReason() {} + +// returnError takes an error descriptor and generates the appropriate ICMP +// error packet for IPv4 and sends it back to the remote device that sent +// the problematic packet. It incorporates as much of that packet as +// possible as well as any error metadata as is available. returnError +// expects pkt to hold a valid IPv4 packet as per the wire format. +func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + // We check we are responding only when we are allowed to. + // See RFC 1812 section 4.3.2.7 (shown below). + // + // ========= + // 4.3.2.7 When Not to Send ICMP Errors + // + // An ICMP error message MUST NOT be sent as the result of receiving: + // + // o An ICMP error message, or + // + // o A packet which fails the IP header validation tests described in + // Section [5.2.2] (except where that section specifically permits + // the sending of an ICMP error message), or + // + // o A packet destined to an IP broadcast or IP multicast address, or + // + // o A packet sent as a Link Layer broadcast or multicast, or + // + // o Any fragment of a datagram other then the first fragment (i.e., a + // packet for which the fragment offset in the IP header is nonzero). + // + // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in + // response to a non-initial fragment, but it currently can not happen. + + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any { + return nil + } + + // Even if we were able to receive a packet from some remote, we may not have + // a route to it - the remote may be blocked via routing rules. We must always + // consult our routing table and find a route to the remote before sending any + // packet. + route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + // From this point on, the incoming route should no longer be used; route + // must be used to send the ICMP error. + r = nil + + sent := p.stack.Stats().ICMP.V4PacketsSent + if !p.stack.AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + + networkHeader := pkt.NetworkHeader().View() + transportHeader := pkt.TransportHeader().View() + + // Don't respond to icmp error packets. + if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) { + // TODO(gvisor.dev/issue/3810): + // Unfortunately the current stack pretty much always has ICMPv4 headers + // in the Data section of the packet but there is no guarantee that is the + // case. If this is the case grab the header to make it like all other + // packet types. When this is cleaned up the Consume should be removed. + if transportHeader.IsEmpty() { + var ok bool + transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize) + if !ok { + return nil + } + } else if transportHeader.Size() < header.ICMPv4MinimumSize { + return nil + } + // We need to decide to explicitly name the packets we can respond to or + // the ones we can not respond to. The decision is somewhat arbitrary and + // if problems arise this could be reversed. It was judged less of a breach + // of protocol to not respond to unknown non-error packets than to respond + // to unknown error packets so we take the first approach. + switch header.ICMPv4(transportHeader).Type() { + case + header.ICMPv4EchoReply, + header.ICMPv4Echo, + header.ICMPv4Timestamp, + header.ICMPv4TimestampReply, + header.ICMPv4InfoRequest, + header.ICMPv4InfoReply: + default: + // Assume any type we don't know about may be an error type. + return nil + } + } + + // Now work out how much of the triggering packet we should return. + // As per RFC 1812 Section 4.3.2.3 + // + // ICMP datagram SHOULD contain as much of the original + // datagram as possible without the length of the ICMP + // datagram exceeding 576 bytes. + // + // NOTE: The above RFC referenced is different from the original + // recommendation in RFC 1122 and RFC 792 where it mentioned that at + // least 8 bytes of the payload must be included. Today linux and other + // systems implement the RFC 1812 definition and not the original + // requirement. We treat 8 bytes as the minimum but will try send more. + mtu := int(route.MTU()) + if mtu > header.IPv4MinimumProcessableDatagramSize { + mtu = header.IPv4MinimumProcessableDatagramSize + } + headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize + available := int(mtu) - headerLen + + if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize { + return nil + } + + payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size() + if payloadLen > available { + payloadLen = available + } + + // The buffers used by pkt may be used elsewhere in the system. + // For example, an AF_RAW or AF_PACKET socket may use what the transport + // protocol considers an unreachable destination. Thus we deep copy pkt to + // prevent multiple ownership and SR errors. The new copy is a vectorized + // view with the entire incoming IP packet reassembled and truncated as + // required. This is now the payload of the new ICMP packet and no longer + // considered a packet in its own right. + newHeader := append(buffer.View(nil), networkHeader...) + newHeader = append(newHeader, transportHeader...) + payload := newHeader.ToVectorisedView() + payload.AppendView(pkt.Data.ToView()) + payload.CapLength(payloadLen) + + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, + }) + + icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber + + icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + var counter *tcpip.StatCounter + switch reason := reason.(type) { + case *icmpReasonPortUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4PortUnreachable) + counter = sent.DstUnreachable + case *icmpReasonProtoUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) + counter = sent.DstUnreachable + case *icmpReasonReassemblyTimeout: + icmpHdr.SetType(header.ICMPv4TimeExceeded) + icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) + counter = sent.TimeExceeded + case *icmpReasonParamProblem: + icmpHdr.SetType(header.ICMPv4ParamProblem) + icmpHdr.SetCode(header.ICMPv4UnusedCode) + icmpHdr.SetPointer(reason.pointer) + counter = sent.ParamProblem + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + + if err := route.WritePacket( + nil, /* gso */ + stack.NetworkHeaderParams{ + Protocol: header.ICMPv4ProtocolNumber, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, + icmpPkt, + ); err != nil { + sent.Dropped.Increment() + return err + } + counter.Increment() + return nil +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 64046cbbf..4592984a5 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -12,26 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ipv4 contains the implementation of the ipv4 network protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv4.NewProtocol() as one of the network -// protocols when calling stack.New(). Then endpoints can be created by passing -// ipv4.ProtocolNumber as the network protocol number when calling -// Stack.NewEndpoint(). +// Package ipv4 contains the implementation of the ipv4 network protocol. package ipv4 import ( + "errors" + "fmt" + "math" "sync/atomic" + "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( + // ReassembleTimeout is the time a packet stays in the reassembly + // system before being evicted. + // As per RFC 791 section 3.2: + // The current recommendation for the initial timer setting is 15 seconds. + // This may be changed as experience with this protocol accumulates. + // + // Considering that it is an old recommendation, we use the same reassembly + // timeout that linux defines, which is 30 seconds: + // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138 + ReassembleTimeout = 30 * time.Second + // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -44,78 +56,141 @@ const ( // buckets is the number of identifier buckets. buckets = 2048 + + // The size of a fragment block, in bytes, as per RFC 791 section 3.1, + // page 14. + fragmentblockSize = 8 ) +var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix() + +var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) + type endpoint struct { - nicID tcpip.NICID - id stack.NetworkEndpointID - prefixLen int - linkEP stack.LinkEndpoint - dispatcher stack.TransportDispatcher - fragmentation *fragmentation.Fragmentation - protocol *protocol - stack *stack.Stack + nic stack.NetworkInterface + dispatcher stack.TransportDispatcher + protocol *protocol + + // enabled is set to 1 when the enpoint is enabled and 0 when it is + // disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + + mu struct { + sync.RWMutex + + addressableEndpointState stack.AddressableEndpointState + } } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ - nicID: nicID, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, - linkEP: linkEP, - dispatcher: dispatcher, - fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), - protocol: p, - stack: st, + nic: nic, + dispatcher: dispatcher, + protocol: p, + } + e.mu.addressableEndpointState.Init(e) + return e +} + +// Enable implements stack.NetworkEndpoint. +func (e *endpoint) Enable() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // If the NIC is not enabled, the endpoint can't do anything meaningful so + // don't enable the endpoint. + if !e.nic.Enabled() { + return tcpip.ErrNotPermitted + } + + // If the endpoint is already enabled, there is nothing for it to do. + if !e.setEnabled(true) { + return nil } - return e, nil + // Create an endpoint to receive broadcast packets on this interface. + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + if err != nil { + return err + } + // We have no need for the address endpoint. + ep.DecRef() + + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts + // multicast group. Note, the IANA calls the all-hosts multicast group the + // all-systems multicast group. + _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems) + return err } -// DefaultTTL is the default time-to-live value for this endpoint. -func (e *endpoint) DefaultTTL() uint8 { - return e.protocol.DefaultTTL() +// Enabled implements stack.NetworkEndpoint. +func (e *endpoint) Enabled() bool { + return e.nic.Enabled() && e.isEnabled() } -// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus -// the network layer max header length. -func (e *endpoint) MTU() uint32 { - return calculateMTU(e.linkEP.MTU()) +// isEnabled returns true if the endpoint is enabled, regardless of the +// enabled status of the NIC. +func (e *endpoint) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) == 1 } -// Capabilities implements stack.NetworkEndpoint.Capabilities. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() +// setEnabled sets the enabled status for the endpoint. +// +// Returns true if the enabled status was updated. +func (e *endpoint) setEnabled(v bool) bool { + if v { + return atomic.SwapUint32(&e.enabled, 1) == 0 + } + return atomic.SwapUint32(&e.enabled, 0) == 1 +} + +// Disable implements stack.NetworkEndpoint. +func (e *endpoint) Disable() { + e.mu.Lock() + defer e.mu.Unlock() + e.disableLocked() } -// NICID returns the ID of the NIC this endpoint belongs to. -func (e *endpoint) NICID() tcpip.NICID { - return e.nicID +func (e *endpoint) disableLocked() { + if !e.setEnabled(false) { + return + } + + // The endpoint may have already left the multicast group. + if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) + } + + // The address may have already been removed. + if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) + } } -// ID returns the ipv4 endpoint ID. -func (e *endpoint) ID() *stack.NetworkEndpointID { - return &e.id +// DefaultTTL is the default time-to-live value for this endpoint. +func (e *endpoint) DefaultTTL() uint8 { + return e.protocol.DefaultTTL() } -// PrefixLen returns the ipv4 endpoint subnet prefix length in bits. -func (e *endpoint) PrefixLen() int { - return e.prefixLen +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +func (e *endpoint) MTU() uint32 { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize) + if err != nil { + return 0 + } + return networkMTU } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize -} - -// GSOMaxSize returns the maximum GSO packet size. -func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { - return gso.GSOMaxSize() - } - return 0 + return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. @@ -123,113 +198,13 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in pkt.Header but does not -// assume that only the IP header is in pkt.Header. It assumes that the input -// packet's stated length matches the length of the header+payload. mtu -// includes the IP header and options. This does not support the DontFragment -// IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error { - // This packet is too big, it needs to be fragmented. - ip := header.IPv4(pkt.Header.View()) - flags := ip.Flags() - - // Update mtu to take into account the header, which will exist in all - // fragments anyway. - innerMTU := mtu - int(ip.HeaderLength()) - - // Round the MTU down to align to 8 bytes. Then calculate the number of - // fragments. Calculate fragment sizes as in RFC791. - innerMTU &^= 7 - n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU - - outerMTU := innerMTU + int(ip.HeaderLength()) - offset := ip.FragmentOffset() - originalAvailableLength := pkt.Header.AvailableLength() - for i := 0; i < n; i++ { - // Where possible, the first fragment that is sent has the same - // pkt.Header.UsedLength() as the input packet. The link-layer - // endpoint may depend on this for looking at, eg, L4 headers. - h := ip - if i > 0 { - pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) - copy(h, ip[:ip.HeaderLength()]) - } - if i != n-1 { - h.SetTotalLength(uint16(outerMTU)) - h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) - } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) - h.SetFlagsFragmentOffset(flags, offset) - } - h.SetChecksum(0) - h.SetChecksum(^h.CalculateChecksum()) - offset += uint16(innerMTU) - if i > 0 { - newPayload := pkt.Data.Clone(nil) - newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - pkt.Data.TrimFront(newPayload.Size()) - continue - } - // Special handling for the first fragment because it comes - // from the header. - if outerMTU >= pkt.Header.UsedLength() { - // This fragment can fit all of pkt.Header and possibly - // some of pkt.Data, too. - newPayload := pkt.Data.Clone(nil) - newPayloadLength := outerMTU - pkt.Header.UsedLength() - newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - pkt.Data.TrimFront(newPayloadLength) - } else { - // The fragment is too small to fit all of pkt.Header. - startOfHdr := pkt.Header - startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) - emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ - Header: startOfHdr, - Data: emptyVV, - NetworkHeader: buffer.View(h), - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of pkt.Header into the pkt.Data - // that remains to be sent. - restOfHdr := pkt.Header.View()[outerMTU:] - tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(pkt.Data) - pkt.Data = tmp - } - } - return nil -} - -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - length := uint16(hdr.UsedLength() + payloadSize) - id := uint32(0) - if length > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) - } +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + length := uint16(pkt.Size()) + // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic + // datagrams. Since the DF bit is never being set here, all datagrams + // are non-atomic and need an ID. + id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: length, @@ -241,64 +216,96 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) - return ip + pkt.NetworkProtocolNumber = ProtocolNumber +} + +// handleFragments fragments pkt and calls the handler function on each +// fragment. It returns the number of fragments handled and the number of +// fragments left to be processed. The IP header must already be present in the +// original packet. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + // Round the MTU down to align to 8 bytes. + fragmentPayloadSize := networkMTU &^ 7 + networkHeader := header.IPv4(pkt.NetworkHeader().View()) + pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadSize, pkt.AvailableHeaderBytes()+len(networkHeader)) + + var n int + for { + fragPkt, more := buildNextFragment(&pf, networkHeader) + if err := handler(fragPkt); err != nil { + return n, pf.RemainingFragmentCount() + 1, err + } + n++ + if !more { + return n, pf.RemainingFragmentCount(), nil + } + } } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { + e.addIPHeader(r, pkt, params) + return e.writePacket(r, gso, pkt) +} - nicName := e.stack.FindNICNameFromID(e.NICID()) +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. - ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok { + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesOutputDropped.Increment() return nil } + // If the packet is manipulated as per NAT Output rules, handle packet + // based on destination address and do not send the packet to link + // layer. + // + // TODO(gvisor.dev/issue/170): We should do this for every + // packet, rather than only NATted packets, but removing this check + // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { - // If the packet is manipulated as per NAT Ouput rules, handle packet - // based on destination address and do not send the packet to link layer. - netHeader := header.IPv4(pkt.NetworkHeader) - ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + netHeader := header.IPv4(pkt.NetworkHeader().View()) + ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) if err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - packet := stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} - ep.HandlePacket(&route, packet) + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + ep.HandlePacket(&route, pkt) return nil } } if r.Loop&stack.PacketLoop != 0 { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - - e.HandlePacket(&loopedR, stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) - + e.HandlePacket(&loopedR, pkt) loopedR.Release() } if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) + + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each + // fragment one by one using WritePacket() (current strategy) or if we + // want to create a PacketBufferList from the fragments and feed it to + // WritePackets(). It'll be faster but cost more memory. + return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) + }) + r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + return err } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.Stats().IP.PacketsSent.Increment() @@ -314,26 +321,49 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } - for pkt := pkts.Front(); pkt != nil; { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) - pkt = pkt.Next() + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.addIPHeader(r, pkt, params) + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + // Keep track of the packet that is about to be fragmented so it can be + // removed once the fragmentation is done. + originalPkt := pkt + if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // Modify the packet list in place with the new fragments. + pkts.InsertAfter(pkt, fragPkt) + pkt = fragPkt + return nil + }); err != nil { + panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", networkMTU, err)) + } + // Remove the packet that was just fragmented and process the rest. + pkts.Remove(originalPkt) + } } - nicName := e.stack.FindNICNameFromID(e.NICID()) + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - ipt := e.stack.IPTables() + ipt := e.protocol.stack.IPTables() dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + } return n, err } + r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) - // Slow Path as we are dropping some packets in the batch degrade to + // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { @@ -341,120 +371,139 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } if _, ok := natPkts[pkt]; ok { - netHeader := header.IPv4(pkt.NetworkHeader) - ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) - if err == nil { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) - - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - packet := stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} - ep.HandlePacket(&route, packet) + ep.HandlePacket(&route, pkt) n++ continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) + // Dropped packets aren't errors, so include them in + // the return value. + return n + len(dropped), err } n++ } r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, nil + // Dropped packets aren't errors, so include them in the return value. + return n + len(dropped), nil } -// WriteHeaderIncludedPacket writes a packet already containing a network -// header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { - return tcpip.ErrInvalidOptionValue + return tcpip.ErrMalformedHeader } ip := header.IPv4(h) - if !ip.IsValid(pkt.Data.Size()) { - return tcpip.ErrInvalidOptionValue - } // Always set the total length. - ip.SetTotalLength(uint16(pkt.Data.Size())) + pktSize := pkt.Data.Size() + ip.SetTotalLength(uint16(pktSize)) // Set the source address when zero. - if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { + if ip.SourceAddress() == header.IPv4Any { ip.SetSourceAddress(r.LocalAddress) } - // Set the destination. If the packet already included a destination, - // it will be part of the route. + // Set the destination. If the packet already included a destination, it will + // be part of the route anyways. ip.SetDestinationAddress(r.RemoteAddress) // Set the packet ID when zero. if ip.ID() == 0 { - id := uint32(0) - if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) + // RFC 6864 section 4.3 mandates uniqueness of ID values for + // non-atomic datagrams, so assign an ID to all such datagrams + // according to the definition given in RFC 6864 section 4. + if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } - ip.SetID(uint16(id)) } // Always set the checksum. ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) - if r.Loop&stack.PacketLoop != 0 { - e.HandlePacket(r, pkt.Clone()) - } - if r.Loop&stack.PacketOut == 0 { - return nil + // Populate the packet buffer's network header and don't allow an invalid + // packet to be sent. + // + // Note that parsing only makes sure that the packet is well formed as per the + // wire format. We also want to check if the header's fields are valid before + // sending the packet. + if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) { + return tcpip.ErrMalformedHeader } - r.Stats().IP.PacketsSent.Increment() - - ip = ip[:ip.HeaderLength()] - pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) - pkt.Data.TrimFront(int(ip.HeaderLength())) - return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + return e.writePacket(r, nil /* gso */, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + if !e.isEnabled() { + return + } + + h := header.IPv4(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } - h := header.IPv4(headerView) - if !h.IsValid(pkt.Data.Size()) { + + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. + // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + if h.CalculateChecksum() != 0xffff { r.Stats().IP.MalformedPacketsReceived.Increment() return } - pkt.NetworkHeader = headerView[:h.HeaderLength()] - hlen := int(h.HeaderLength()) - tlen := int(h.TotalLength()) - pkt.Data.TrimFront(hlen) - pkt.Data.CapLength(tlen - hlen) + // As per RFC 1122 section 3.2.1.3: + // When a host sends any datagram, the IP source address MUST + // be one of its own IP addresses (but not a broadcast or + // multicast address). + if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) { + r.Stats().IP.InvalidSourceAddressesReceived.Increment() + return + } // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok { + ipt := e.protocol.stack.IPTables() + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesInputDropped.Increment() return } - more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 - if more || h.FragmentOffset() != 0 { - if pkt.Data.Size() == 0 { + if h.More() || h.FragmentOffset() != 0 { + if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -462,18 +511,60 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { return } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 - // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and pkt.Data.size() causes a - // wrap around resulting in last being less than the offset. - if last < h.FragmentOffset() { + start := h.FragmentOffset() + // Drop the fragment if the size of the reassembled payload would exceed the + // maximum payload size. + // + // Note that this addition doesn't overflow even on 32bit architecture + // because pkt.Data.Size() should not exceed 65535 (the max IP datagram + // size). Otherwise the packet would've been rejected as invalid before + // reaching here. + if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() return } + + // Set up a callback in case we need to send a Time Exceeded Message, as per + // RFC 792: + // + // If a host reassembling a fragmented datagram cannot complete the + // reassembly due to missing fragments within its time limit it discards + // the datagram, and it may send a time exceeded message. + // + // If fragment zero is not available then no time exceeded need be sent at + // all. + var releaseCB func(bool) + if start == 0 { + pkt := pkt.Clone() + r := r.Clone() + releaseCB = func(timedOut bool) { + if timedOut { + _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt) + } + r.Release() + } + } + var ready bool var err error - pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) + proto := h.Protocol() + pkt.Data, _, ready, err = e.protocol.fragmentation.Process( + // As per RFC 791 section 2.3, the identification value is unique + // for a source-destination pair and protocol. + fragmentation.FragmentID{ + Source: h.SourceAddress(), + Destination: h.DestinationAddress(), + ID: uint32(h.ID()), + Protocol: proto, + }, + start, + start+uint16(pkt.Data.Size())-1, + h.More(), + proto, + pkt.Data, + releaseCB, + ) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -482,28 +573,193 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { if !ready { return } + + // The reassembler doesn't take care of fixing up the header, so we need + // to do it here. + h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) + h.SetFlagsFragmentOffset(0, 0) } + r.Stats().IP.PacketsDelivered.Increment() + p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - headerView.CapLength(hlen) + // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport + // headers, the setting of the transport number here should be + // unnecessary and removed. + pkt.TransportProtocolNumber = p e.handleICMP(r, pkt) return } - r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, pkt) + if len(h.Options()) != 0 { + // TODO(gvisor.dev/issue/4586): + // When we add forwarding support we should use the verified options + // rather than just throwing them away. + aux, _, err := processIPOptions(r, h.Options(), &optionUsageReceive{}) + if err != nil { + switch { + case + errors.Is(err, header.ErrIPv4OptDuplicate), + errors.Is(err, errIPv4RecordRouteOptInvalidPointer), + errors.Is(err, errIPv4RecordRouteOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptOverflow): + _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt) + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + r.Stats().IP.MalformedPacketsReceived.Increment() + } + return + } + } + + switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + case stack.TransportPacketHandled: + case stack.TransportPacketDestinationPortUnreachable: + // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination + // Unreachable messages with code: + // 3 (Port Unreachable), when the designated transport protocol + // (e.g., UDP) is unable to demultiplex the datagram but has no + // protocol mechanism to inform the sender. + _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + case stack.TransportPacketProtocolUnreachable: + // As per RFC: 1122 Section 3.2.2.1 + // A host SHOULD generate Destination Unreachable messages with code: + // 2 (Protocol Unreachable), when the designated transport protocol + // is not supported + _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt) + default: + panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) + } } // Close cleans up resources associated with the endpoint. -func (e *endpoint) Close() {} +func (e *endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() + + e.disableLocked() + e.mu.addressableEndpointState.Cleanup() +} + +// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) +} + +// RemovePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.RemovePermanentAddress(addr) +} + +// MainAddress implements stack.AddressableEndpoint. +func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.MainAddress() +} + +// AcquireAssignedAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + e.mu.Lock() + defer e.mu.Unlock() + + loopback := e.nic.IsLoopback() + addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool { + subnet := addressEndpoint.AddressWithPrefix().Subnet() + // IPv4 has a notion of a subnet broadcast address and considers the + // loopback interface bound to an address's whole subnet (on linux). + return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr)) + }) + if addressEndpoint != nil { + return addressEndpoint + } + + if !allowTemp { + return nil + } + + addr := localAddr.WithPrefix() + addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB) + if err != nil { + // AddAddress only returns an error if the address is already assigned, + // but we just checked above if the address exists so we expect no error. + panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err)) + } + return addressEndpoint +} + +// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) +} + +// PrimaryAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PrimaryAddresses() +} + +// PermanentAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PermanentAddresses() +} + +// JoinGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + if !header.IsV4MulticastAddress(addr) { + return false, tcpip.ErrBadAddress + } + + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.JoinGroup(addr) +} + +// LeaveGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.LeaveGroup(addr) +} + +// IsInGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) IsInGroup(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.IsInGroup(addr) +} + +var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) +var _ stack.NetworkProtocol = (*protocol)(nil) type protocol struct { - ids []uint32 - hashIV uint32 + stack *stack.Stack // defaultTTL is the current default TTL for the protocol. Only the - // uint8 portion of it is meaningful and it must be accessed - // atomically. + // uint8 portion of it is meaningful. + // + // Must be accessed using atomic operations. defaultTTL uint32 + + // forwarding is set to 1 when the protocol has forwarding enabled and 0 + // when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + + ids []uint32 + hashIV uint32 + + fragmentation *fragmentation.Fragmentation } // Number returns the ipv4 protocol number. @@ -528,10 +784,10 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // SetOption implements NetworkProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case tcpip.DefaultTTLOption: - p.SetDefaultTTL(uint8(v)) + case *tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(*v)) return nil default: return tcpip.ErrUnknownProtocolOption @@ -539,7 +795,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } // Option implements NetworkProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(p.DefaultTTL()) @@ -565,28 +821,80 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// calculateMTU calculates the network-layer payload MTU based on the link-layer -// payload mtu. -func calculateMTU(mtu uint32) uint32 { - if mtu > MaxTotalSize { - mtu = MaxTotalSize +// Parse implements stack.NetworkProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + if ok := parse.IPv4(pkt); !ok { + return 0, false, false + } + + ipHdr := header.IPv4(pkt.NetworkHeader().View()) + return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true +} + +// Forwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) Forwarding() bool { + return uint8(atomic.LoadUint32(&p.forwarding)) == 1 +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) SetForwarding(v bool) { + if v { + atomic.StoreUint32(&p.forwarding, 1) + } else { + atomic.StoreUint32(&p.forwarding, 0) + } +} + +// calculateNetworkMTU calculates the network-layer payload MTU based on the +// link-layer payload mtu. +func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) { + if linkMTU < header.IPv4MinimumMTU { + return 0, tcpip.ErrInvalidEndpointState + } + + // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in + // length: + // The maximal internet header is 60 octets, and a typical internet header + // is 20 octets, allowing a margin for headers of higher level protocols. + if networkHeaderSize > header.IPv4MaximumHeaderSize { + return 0, tcpip.ErrMalformedHeader + } + + networkMTU := linkMTU + if networkMTU > MaxTotalSize { + networkMTU = MaxTotalSize } - return mtu - header.IPv4MinimumSize + + return networkMTU - uint32(networkHeaderSize), nil +} + +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { + payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU +} + +// addressToUint32 translates an IPv4 address into its little endian uint32 +// representation. +// +// This function does the same thing as binary.LittleEndian.Uint32 but operates +// on a tcpip.Address (a string) without the need to convert it to a byte slice, +// which would cause an allocation. +func addressToUint32(addr tcpip.Address) uint32 { + _ = addr[3] // bounds check hint to compiler + return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24 } // hashRoute calculates a hash value for the given route. It uses the source & -// destination address, the transport protocol number, and a random initial -// value (generated once on initialization) to generate the hash. +// destination address, the transport protocol number and a 32-bit number to +// generate the hash. func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { - t := r.LocalAddress - a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 - t = r.RemoteAddress - b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + a := addressToUint32(r.LocalAddress) + b := addressToUint32(r.RemoteAddress) return hash.Hash3Words(a, b, uint32(protocol), hashIV) } // NewProtocol returns an IPv4 network protocol. -func NewProtocol() stack.NetworkProtocol { +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. @@ -596,5 +904,353 @@ func NewProtocol() stack.NetworkProtocol { } hashIV := r[buckets] - return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL} + return &protocol{ + stack: s, + ids: ids, + hashIV: hashIV, + defaultTTL: DefaultTTL, + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), + } +} + +func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) { + fragPkt, offset, copied, more := pf.BuildNextFragment() + fragPkt.NetworkProtocolNumber = ProtocolNumber + + originalIPHeaderLength := len(originalIPHeader) + nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength)) + + if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) { + panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength)) + } + + flags := originalIPHeader.Flags() + if more { + flags |= header.IPv4FlagMoreFragments + } + nextFragIPHeader.SetFlagsFragmentOffset(flags, uint16(offset)) + nextFragIPHeader.SetTotalLength(uint16(nextFragIPHeader.HeaderLength()) + uint16(copied)) + nextFragIPHeader.SetChecksum(0) + nextFragIPHeader.SetChecksum(^nextFragIPHeader.CalculateChecksum()) + + return fragPkt, more +} + +// optionAction describes possible actions that may be taken on an option +// while processing it. +type optionAction uint8 + +const ( + // optionRemove says that the option should not be in the output option set. + optionRemove optionAction = iota + + // optionProcess says that the option should be fully processed. + optionProcess + + // optionVerify says the option should be checked and passed unchanged. + optionVerify + + // optionPass says to pass the output set without checking. + optionPass +) + +// optionActions list what to do for each option in a given scenario. +type optionActions struct { + // timestamp controls what to do with a Timestamp option. + timestamp optionAction + + // recordroute controls what to do with a Record Route option. + recordRoute optionAction + + // unknown controls what to do with an unknown option. + unknown optionAction +} + +// optionsUsage specifies the ways options may be operated upon for a given +// scenario during packet processing. +type optionsUsage interface { + actions() optionActions +} + +// optionUsageReceive implements optionsUsage for received packets. +type optionUsageReceive struct{} + +// actions implements optionsUsage. +func (*optionUsageReceive) actions() optionActions { + return optionActions{ + timestamp: optionVerify, + recordRoute: optionVerify, + unknown: optionPass, + } +} + +// TODO(gvisor.dev/issue/4586): Add an entry here for forwarding when it +// is enabled (Process, Process, Pass) and for fragmenting (Process, Process, +// Pass for frag1, but Remove,Remove,Remove for all other frags). + +// optionUsageEcho implements optionsUsage for echo packet processing. +type optionUsageEcho struct{} + +// actions implements optionsUsage. +func (*optionUsageEcho) actions() optionActions { + return optionActions{ + timestamp: optionProcess, + recordRoute: optionProcess, + unknown: optionRemove, + } +} + +var ( + errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length") + errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer") + errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp") + errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags") +) + +// handleTimestamp does any required processing on a Timestamp option +// in place. +func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) { + flags := tsOpt.Flags() + var entrySize uint8 + switch flags { + case header.IPv4OptionTimestampOnlyFlag: + entrySize = header.IPv4OptionTimestampSize + case + header.IPv4OptionTimestampWithIPFlag, + header.IPv4OptionTimestampWithPredefinedIPFlag: + entrySize = header.IPv4OptionTimestampWithAddrSize + default: + return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags + } + + pointer := tsOpt.Pointer() + // To simplify processing below, base further work on the array of timestamps + // beyond the header, rather than on the whole option. Also to aid + // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based. + nextSlot := pointer - (header.IPv4OptionTimestampHdrLength + 1) + optLen := tsOpt.Size() + dataLength := optLen - header.IPv4OptionTimestampHdrLength + + // In the section below, we verify the pointer, length and overflow counter + // fields of the option. The distinction is in which byte you return as being + // in error in the ICMP packet. Offsets 1 (length), 2 pointer) + // or 3 (overflowed counter). + // + // The following RFC sections cover this section: + // + // RFC 791 (page 22): + // If there is some room but not enough room for a full timestamp + // to be inserted, or the overflow count itself overflows, the + // original datagram is considered to be in error and is discarded. + // In either case an ICMP parameter problem message may be sent to + // the source host [3]. + // + // You can get this situation in two ways. Firstly if the data area is not + // a multiple of the entry size or secondly, if the pointer is not at a + // multiple of the entry size. The wording of the RFC suggests that + // this is not an error until you actually run out of space. + if pointer > optLen { + // RFC 791 (page 22) says we should switch to using the overflow count. + // If the timestamp data area is already full (the pointer exceeds + // the length) the datagram is forwarded without inserting the + // timestamp, but the overflow count is incremented by one. + if flags == header.IPv4OptionTimestampWithPredefinedIPFlag { + // By definition we have nothing to do. + return 0, nil + } + + if tsOpt.IncOverflow() != 0 { + return 0, nil + } + // The overflow count is also full. + return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow + } + if nextSlot+entrySize > dataLength { + // The data area isn't full but there isn't room for a new entry. + // Either Length or Pointer could be bad. + if false { + // We must select Pointer for Linux compatibility, even if + // only the length is bad. + // The Linux code is at (in October 2020) + // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L367-L370 + // if (optptr[2]+3 > optlen) { + // pp_ptr = optptr + 2; + // goto error; + // } + // which doesn't distinguish between which of optptr[2] or optlen + // is wrong, but just arbitrarily decides on optptr+2. + if dataLength%entrySize != 0 { + // The Data section size should be a multiple of the expected + // timestamp entry size. + return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength + } + // If the size is OK, the pointer must be corrupted. + } + return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + } + + if usage.actions().timestamp == optionProcess { + tsOpt.UpdateTimestamp(localAddress, clock) + } + return 0, nil +} + +var ( + errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route") + errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route") +) + +// handleRecordRoute checks and processes a Record route option. It is much +// like the timestamp type 1 option, but without timestamps. The passed in +// address is stored in the option in the correct spot if possible. +func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) { + optlen := rrOpt.Size() + + if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength { + return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + } + + nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based. + + // RFC 791 page 21 says + // If the route data area is already full (the pointer exceeds the + // length) the datagram is forwarded without inserting the address + // into the recorded route. If there is some room but not enough + // room for a full address to be inserted, the original datagram is + // considered to be in error and is discarded. In either case an + // ICMP parameter problem message may be sent to the source + // host. + // The use of the words "In either case" suggests that a 'full' RR option + // could generate an ICMP at every hop after it fills up. We chose to not + // do this (as do most implementations). It is probable that the inclusion + // of these words is a copy/paste error from the timestamp option where + // there are two failure reasons given. + if nextSlot >= optlen { + return 0, nil + } + + // The data area isn't full but there isn't room for a new entry. + // Either Length or Pointer could be bad. We must select Pointer for Linux + // compatibility, even if only the length is bad. + if nextSlot+header.IPv4AddressSize > optlen { + if false { + // This is what we would do if we were not being Linux compatible. + // Check for bad pointer or length value. Must be a multiple of 4 after + // accounting for the 3 byte header and not within that header. + // RFC 791, page 20 says: + // The pointer is relative to this option, and the + // smallest legal value for the pointer is 4. + // + // A recorded route is composed of a series of internet addresses. + // Each internet address is 32 bits or 4 octets. + // Linux skips this test so we must too. See Linux code at: + // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L338-L341 + // if (optptr[2]+3 > optlen) { + // pp_ptr = optptr + 2; + // goto error; + // } + if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 { + // Length is bad, not on integral number of slots. + return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + } + // If not length, the fault must be with the pointer. + } + return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + } + if usage.actions().recordRoute == optionVerify { + return 0, nil + } + rrOpt.StoreAddress(localAddress) + return 0, nil +} + +// processIPOptions parses the IPv4 options and produces a new set of options +// suitable for use in the next step of packet processing as informed by usage. +// The original will not be touched. +// +// Returns +// - The location of an error if there was one (or 0 if no error) +// - If there is an error, information as to what it was was. +// - The replacement option set. +func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { + + opts := header.IPv4Options(orig) + optIter := opts.MakeIterator() + + // Each option other than NOP must only appear (RFC 791 section 3.1, at the + // definition of every type). Keep track of each of the possible types in + // the 8 bit 'type' field. + var seenOptions [math.MaxUint8 + 1]bool + + // TODO(gvisor.dev/issue/4586): + // This will need tweaking when we start really forwarding packets + // as we may need to get two addresses, for rx and tx interfaces. + // We will also have to take usage into account. + prefixedAddress, err := r.Stack().GetMainNICAddress(r.NICID(), ProtocolNumber) + localAddress := prefixedAddress.Address + if err != nil { + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) { + return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress + } + localAddress = r.LocalAddress + } + + for { + option, done, err := optIter.Next() + if done || err != nil { + return optIter.ErrCursor, optIter.Finalize(), err + } + optType := option.Type() + if optType == header.IPv4OptionNOPType { + optIter.PushNOPOrEnd(optType) + continue + } + if optType == header.IPv4OptionListEndType { + optIter.PushNOPOrEnd(optType) + return 0 /* errCursor */, optIter.Finalize(), nil /* err */ + } + + // check for repeating options (multiple NOPs are OK) + if seenOptions[optType] { + return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate + } + seenOptions[optType] = true + + optLen := int(option.Size()) + switch option := option.(type) { + case *header.IPv4OptionTimestamp: + r.Stats().IP.OptionTSReceived.Increment() + if usage.actions().timestamp != optionRemove { + clock := r.Stack().Clock() + newBuffer := optIter.RemainingBuffer()[:len(*option)] + _ = copy(newBuffer, option.Contents()) + offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage) + if err != nil { + return optIter.ErrCursor + offset, nil, err + } + optIter.ConsumeBuffer(optLen) + } + + case *header.IPv4OptionRecordRoute: + r.Stats().IP.OptionRRReceived.Increment() + if usage.actions().recordRoute != optionRemove { + newBuffer := optIter.RemainingBuffer()[:len(*option)] + _ = copy(newBuffer, option.Contents()) + offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage) + if err != nil { + return optIter.ErrCursor + offset, nil, err + } + optIter.ConsumeBuffer(optLen) + } + + default: + r.Stats().IP.OptionUnknownReceived.Increment() + if usage.actions().unknown == optionPass { + newBuffer := optIter.RemainingBuffer()[:optLen] + // Arguments already heavily checked.. ignore result. + _ = copy(newBuffer, option.Contents()) + optIter.ConsumeBuffer(optLen) + } + } + } } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 36035c820..61672a5ff 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -15,31 +15,43 @@ package ipv4_test import ( - "bytes" + "context" "encoding/hex" - "math/rand" + "fmt" + "math" + "net" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) +const ( + extraHeaderReserve = 50 + defaultMTU = 65536 +) + func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - const defaultMTU = 65536 ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { ep = sniffer.New(ep) @@ -91,35 +103,672 @@ func TestExcludeBroadcast(t *testing.T) { }) } -// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much -// data should already be in the header before WritePacket. extraLength -// indicates how much extra space should be in the header. The payload is made -// from many Views of the sizes listed in viewSizes. -func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { - hdr := buffer.NewPrependable(hdrLength + extraLength) - hdr.Prepend(hdrLength) - rand.Read(hdr.View()) - - var views []buffer.View - totalLength := 0 - for _, s := range viewSizes { - newView := buffer.NewView(s) - rand.Read(newView) - views = append(views, newView) - totalLength += s - } - payload := buffer.NewVectorisedView(totalLength, views) - return hdr, payload +// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and +// checks the response. +func TestIPv4Sanity(t *testing.T) { + const ( + ttl = 255 + nicID = 1 + randomSequence = 123 + randomIdent = 42 + ) + var ( + ipv4Addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + } + remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) + ) + + tests := []struct { + name string + headerLength uint8 // value of 0 means "use correct size" + badHeaderChecksum bool + maxTotalLength uint16 + transportProtocol uint8 + TTL uint8 + options []byte + replyOptions []byte // if succeeds, reply should look like this + shouldFail bool + expectErrorICMP bool + ICMPType header.ICMPv4Type + ICMPCode header.ICMPv4Code + paramProblemPointer uint8 + }{ + { + name: "valid no options", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + }, + { + name: "bad header checksum", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + badHeaderChecksum: true, + shouldFail: true, + }, + // The TTL tests check that we are not rejecting an incoming packet + // with a zero or one TTL, which has been a point of confusion in the + // past as RFC 791 says: "If this field contains the value zero, then the + // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies + // for the case of the destination host, stating as follows. + // + // A host MUST NOT send a datagram with a Time-to-Live (TTL) + // value of zero. + // + // A host MUST NOT discard a datagram just because it was + // received with TTL less than 2. + { + name: "zero TTL", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 0, + }, + { + name: "one TTL", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 1, + }, + { + name: "End options", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{0, 0, 0, 0}, + replyOptions: []byte{0, 0, 0, 0}, + }, + { + name: "NOP options", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{1, 1, 1, 1}, + replyOptions: []byte{1, 1, 1, 1}, + }, + { + name: "NOP and End options", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{1, 1, 0, 0}, + replyOptions: []byte{1, 1, 0, 0}, + }, + { + name: "bad header length", + headerLength: header.IPv4MinimumSize - 1, + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + }, + { + name: "bad total length (0)", + maxTotalLength: 0, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + }, + { + name: "bad total length (ip - 1)", + maxTotalLength: uint16(header.IPv4MinimumSize - 1), + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + }, + { + name: "bad total length (ip + icmp - 1)", + maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1), + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + }, + { + name: "bad protocol", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: 99, + TTL: ttl, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4DstUnreachable, + ICMPCode: header.ICMPv4ProtoUnreachable, + }, + { + name: "timestamp option overflow", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + replyOptions: []byte{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + }, + { + name: "timestamp option overflow full", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0xF1, + // ^ Counter full (15/0xF) + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 3, + replyOptions: []byte{}, + }, + { + name: "unknown option", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{10, 4, 9, 0}, + // ^^ + // The unknown option should be stripped out of the reply. + replyOptions: []byte{}, + }, + { + name: "bad option - length 0", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 0, 9, 0, + // ^ + 1, 2, 3, 4, + }, + shouldFail: true, + }, + { + name: "bad option - length big", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 9, 9, 0, + // ^ + // There are only 8 bytes allocated to options so 9 bytes of timestamp + // space is not possible. (Second byte) + 1, 2, 3, 4, + }, + shouldFail: true, + }, + { + // This tests for some linux compatible behaviour. + // The ICMP pointer returned is 22 for Linux but the + // error is actually in spot 21. + name: "bad option - length bad", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + // Timestamps are in multiples of 4 or 8 but never 7. + // The option space should be padded out. + options: []byte{ + 68, 7, 5, 0, + // ^ ^ Linux points here which is wrong. + // | Not a multiple of 4 + 1, 2, 3, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + name: "multiple type 0 with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 24, 21, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + }, + replyOptions: []byte{ + 68, 24, 25, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // The timestamp area is full so add to the overflow count. + name: "multiple type 1 timestamps", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 20, 21, 0x11, + // ^ + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + }, + // Overflow count is the top nibble of the 4th byte. + replyOptions: []byte{ + 68, 20, 21, 0x21, + // ^ + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + }, + }, + { + name: "multiple type 1 timestamps with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 28, 21, 0x01, + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + replyOptions: []byte{ + 68, 28, 29, 0x01, + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + 192, 168, 1, 58, // New IP Address. + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // Needs 8 bytes for a type 1 timestamp but there are only 4 free. + name: "bad timer element alignment", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 20, 17, 0x01, + // ^^ ^^ 20 byte area, next free spot at 17. + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + // End of option list with illegal option after it, which should be ignored. + { + name: "end of options list", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 10, 3, 99, + }, + replyOptions: []byte{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 0, 0, 0, // 3 bytes unknown option + }, // ^ End of options hides following bytes. + }, + { + // Timestamp with a size too small. + name: "timestamp truncated", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{68, 1, 0, 0}, + // ^ Smallest possible is 8. + shouldFail: true, + }, + { + name: "single record route with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 4, // 3 byte header + 0, 0, 0, 0, + 0, + }, + replyOptions: []byte{ + 7, 7, 8, // 3 byte header + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + name: "multiple record route with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 23, 20, // 3 byte header + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + 0, + }, + replyOptions: []byte{ + 7, 23, 24, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + name: "single record route with no room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, + }, + replyOptions: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Unlike timestamp, this should just succeed. + name: "multiple record route with no room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 23, 24, // 3 byte header + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 0, + }, + replyOptions: []byte{ + 7, 23, 24, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Confirm linux bug for bug compatibility. + // Linux returns slot 22 but the error is in slot 21. + name: "multiple record route with not enough room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 8, 8, // 3 byte header + // ^ ^ Linux points here. We must too. + // | Not enough room. 1 byte free, need 4. + 1, 2, 3, 4, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + replyOptions: []byte{}, + }, + { + name: "duplicate record route", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, 0, // pad + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 7, + replyOptions: []byte{}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + Clock: clock, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e := channel.New(1, ipv4.MaxTotalSize, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + } + // Advance the clock by some unimportant amount to make + // sure it's all set up. + clock.Advance(time.Millisecond * 0x10203040) + + // Default routes for IPv4 so ICMP can find a route to the remote + // node when attempting to send the ICMP Echo Reply. + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + // Round up the header size to the next multiple of 4 as RFC 791, page 11 + // says: "Internet Header Length is the length of the internet header + // in 32 bit words..." and on page 23: "The internet header padding is + // used to ensure that the internet header ends on a 32 bit boundary." + ipHeaderLength := ((header.IPv4MinimumSize + len(test.options)) + header.IPv4IHLStride - 1) & ^(header.IPv4IHLStride - 1) + + if ipHeaderLength > header.IPv4MaximumHeaderSize { + t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + } + totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(totalLen)) + icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + + // Specify ident/seq to make sure we get the same in the response. + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv4Echo) + icmp.SetCode(header.ICMPv4UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(^header.Checksum(icmp, 0)) + ip := header.IPv4(hdr.Prepend(ipHeaderLength)) + if test.maxTotalLength < totalLen { + totalLen = test.maxTotalLength + } + ip.Encode(&header.IPv4Fields{ + IHL: uint8(ipHeaderLength), + TotalLength: totalLen, + Protocol: test.transportProtocol, + TTL: test.TTL, + SrcAddr: remoteIPv4Addr, + DstAddr: ipv4Addr.Address, + }) + if n := copy(ip.Options(), test.options); n != len(test.options) { + t.Fatalf("options larger than available space: copied %d/%d bytes", n, len(test.options)) + } + // Override the correct value if the test case specified one. + if test.headerLength != 0 { + ip.SetHeaderLength(test.headerLength) + } + ip.SetChecksum(0) + ipHeaderChecksum := ip.CalculateChecksum() + if test.badHeaderChecksum { + ipHeaderChecksum += 42 + } + ip.SetChecksum(^ipHeaderChecksum) + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + reply, ok := e.Read() + if !ok { + if test.shouldFail { + if test.expectErrorICMP { + t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode) + } + return // Expected silent failure. + } + t.Fatal("expected ICMP echo reply missing") + } + + // We didn't expect a packet. Register our surprise but carry on to + // provide more information about what we got. + if test.shouldFail && !test.expectErrorICMP { + t.Error("unexpected packet response") + } + + // Check the route that brought the packet to us. + if reply.Route.LocalAddress != ipv4Addr.Address { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) + } + if reply.Route.RemoteAddress != remoteIPv4Addr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) + } + + // Make sure it's all in one buffer for checker. + replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())) + + // At this stage we only know it's probably an IP+ICMP header so verify + // that much. + checker.IPv4(t, replyIPHeader, + checker.SrcAddr(ipv4Addr.Address), + checker.DstAddr(remoteIPv4Addr), + checker.ICMPv4( + checker.ICMPv4Checksum(), + ), + ) + + // Don't proceed any further if the checker found problems. + if t.Failed() { + t.FailNow() + } + + // OK it's ICMP. We can safely look at the type now. + replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) + switch replyICMPHeader.Type() { + case header.ICMPv4ParamProblem: + if !test.shouldFail { + t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer()) + } + if !test.expectErrorICMP { + t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer()) + } + checker.IPv4(t, replyIPHeader, + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(test.ICMPType), + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Pointer(test.paramProblemPointer), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) + return + case header.ICMPv4DstUnreachable: + if !test.shouldFail { + t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply", + header.ICMPv4DstUnreachable, replyICMPHeader.Code()) + } + if !test.expectErrorICMP { + t.Fatalf("got ICMP error packet type %d, code %d, wanted no response", + header.ICMPv4DstUnreachable, replyICMPHeader.Code()) + } + checker.IPv4(t, replyIPHeader, + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(test.ICMPType), + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) + return + case header.ICMPv4EchoReply: + if test.shouldFail { + if !test.expectErrorICMP { + t.Error("got Echo Reply packet, want no response") + } else { + t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode) + } + } + // If the IP options change size then the packet will change size, so + // some IP header fields will need to be adjusted for the checks. + sizeChange := len(test.replyOptions) - len(test.options) + + checker.IPv4(t, replyIPHeader, + checker.IPv4HeaderLength(ipHeaderLength+sizeChange), + checker.IPv4Options(test.replyOptions), + checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Code(header.ICMPv4UnusedCode), + checker.ICMPv4Seq(randomSequence), + checker.ICMPv4Ident(randomIdent), + ), + ) + default: + t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d", + replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem) + } + }) + } } // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) { - t.Helper() - // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) - source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Data.ToView()...) +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { + // Make a complete array of the sourcePacket packet. + source := header.IPv4(packets[0].NetworkHeader().View()) + vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) + source = append(source, vv.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -127,350 +776,925 @@ func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketIn sourceCopy.SetChecksum(0) sourceCopy.SetFlagsFragmentOffset(0, 0) sourceCopy.SetTotalLength(0) - var offset uint16 // Build up an array of the bytes sent. - var reassembledPayload []byte + var reassembledPayload buffer.VectorisedView for i, packet := range packets { // Confirm that the packet is valid. - allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Data) - ip := header.IPv4(allBytes.ToView()) - if !ip.IsValid(len(ip)) { - t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) + allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) + fragmentIPHeader := header.IPv4(allBytes.ToView()) + if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) { + return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader)) + } + if got := len(fragmentIPHeader); got > int(mtu) { + return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu) } - if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want { - t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want) + if got := fragmentIPHeader.TransportProtocol(); got != proto { + return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto)) } - if got, want := len(ip), int(mtu); got > want { - t.Errorf("fragment is too large, got %d want %d", got, want) + if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve { + return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) } - if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { + return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want) } - if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { - t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) + if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { + return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want) } - if i < len(packets)-1 { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) + if wantFragments[i].more { + sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset) } else { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) + sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset) } - reassembledPayload = append(reassembledPayload, ip.Payload()...) - offset += ip.TotalLength() - uint16(ip.HeaderLength()) + reassembledPayload.AppendView(packet.TransportHeader().View()) + reassembledPayload.Append(packet.Data) // Clear out the checksum and length from the ip because we can't compare // it. - sourceCopy.SetTotalLength(uint16(len(ip))) + sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) sourceCopy.SetChecksum(0) sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) - if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) { - t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()])) + if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { + return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) } } - expected := source[source.HeaderLength():] - if !bytes.Equal(reassembledPayload, expected) { - t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected)) + + expected := buffer.View(source[source.HeaderLength():]) + if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" { + return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) } -} -type errorChannel struct { - *channel.Endpoint - Ch chan stack.PacketBuffer - packetCollectorErrors []*tcpip.Error + return nil } -// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket -// will return successive errors from packetCollectorErrors until the list is -// empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { - return &errorChannel{ - Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan stack.PacketBuffer, size), - packetCollectorErrors: packetCollectorErrors, - } +type fragmentInfo struct { + offset uint16 + more bool + payloadSize uint16 } -// Drain removes all outbound packets from the channel and counts them. -func (e *errorChannel) Drain() int { - c := 0 - for { - select { - case <-e.Ch: - c++ - default: - return c - } - } +var fragmentationTests = []struct { + description string + mtu uint32 + gso *stack.GSO + transportHeaderLength int + payloadSize int + wantFragments []fragmentInfo +}{ + { + description: "No fragmentation", + mtu: 1280, + gso: nil, + transportHeaderLength: 0, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1000, more: false}, + }, + }, + { + description: "Fragmented", + mtu: 1280, + gso: nil, + transportHeaderLength: 0, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1256, more: true}, + {offset: 1256, payloadSize: 744, more: false}, + }, + }, + { + description: "Fragmented with the minimum mtu", + mtu: header.IPv4MinimumMTU, + gso: nil, + transportHeaderLength: 0, + payloadSize: 100, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 48, more: true}, + {offset: 48, payloadSize: 48, more: true}, + {offset: 96, payloadSize: 4, more: false}, + }, + }, + { + description: "Fragmented with mtu not a multiple of 8", + mtu: header.IPv4MinimumMTU + 1, + gso: nil, + transportHeaderLength: 0, + payloadSize: 100, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 48, more: true}, + {offset: 48, payloadSize: 48, more: true}, + {offset: 96, payloadSize: 4, more: false}, + }, + }, + { + description: "No fragmentation with big header", + mtu: 2000, + gso: nil, + transportHeaderLength: 100, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1100, more: false}, + }, + }, + { + description: "Fragmented with gso none", + mtu: 1280, + gso: &stack.GSO{Type: stack.GSONone}, + transportHeaderLength: 0, + payloadSize: 1400, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1256, more: true}, + {offset: 1256, payloadSize: 144, more: false}, + }, + }, + { + description: "Fragmented with big header", + mtu: 1280, + gso: nil, + transportHeaderLength: 100, + payloadSize: 1200, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1256, more: true}, + {offset: 1256, payloadSize: 44, more: false}, + }, + }, + { + description: "Fragmented with MTU smaller than header", + mtu: 300, + gso: nil, + transportHeaderLength: 1000, + payloadSize: 500, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 280, more: true}, + {offset: 280, payloadSize: 280, more: true}, + {offset: 560, payloadSize: 280, more: true}, + {offset: 840, payloadSize: 280, more: true}, + {offset: 1120, payloadSize: 280, more: true}, + {offset: 1400, payloadSize: 100, more: false}, + }, + }, } -// WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { - select { - case e.Ch <- pkt: - default: - } +func TestFragmentationWritePacket(t *testing.T) { + const ttl = 42 - nextError := (*tcpip.Error)(nil) - if len(e.packetCollectorErrors) > 0 { - nextError = e.packetCollectorErrors[0] - e.packetCollectorErrors = e.packetCollectorErrors[1:] + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + source := pkt.Clone() + err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }, pkt) + if err != nil { + t.Fatalf("r.WritePacket(_, _, _) = %s", err) + } + if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) + } + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + }) } - return nextError } -type context struct { - stack.Route - linkEP *errorChannel -} - -func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { - // Make the packet and write it. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - }) - ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - s.CreateNIC(1, ep) - const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" - ) - s.AddAddress(1, ipv4.ProtocolNumber, src) - { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) - } - return context{ - Route: r, - linkEP: ep, +func TestFragmentationWritePackets(t *testing.T) { + const ttl = 42 + writePacketsTests := []struct { + description string + insertBefore int + insertAfter int + }{ + { + description: "Single packet", + insertBefore: 0, + insertAfter: 0, + }, + { + description: "With packet before", + insertBefore: 1, + insertAfter: 0, + }, + { + description: "With packet after", + insertBefore: 0, + insertAfter: 1, + }, + { + description: "With packet before and after", + insertBefore: 1, + insertAfter: 1, + }, } -} + tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) -func TestFragmentation(t *testing.T) { - var manyPayloadViewsSizes [1000]int - for i := range manyPayloadViewsSizes { - manyPayloadViewsSizes[i] = 7 - } - fragTests := []struct { - description string - mtu uint32 - gso *stack.GSO - hdrLength int - extraLength int - payloadViewsSizes []int - expectedFrags int - }{ - {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, - {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, - {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, - {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, - } - - for _, ft := range fragTests { - t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := stack.PacketBuffer{ - Header: hdr, - // Save the source payload because WritePacket will modify it. - Data: payload.Clone(nil), - } - c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ - Header: hdr, - Data: payload, - }) - if err != nil { - t.Errorf("err got %v, want %v", err, nil) - } + for _, test := range writePacketsTests { + t.Run(test.description, func(t *testing.T) { + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + var pkts stack.PacketBufferList + for i := 0; i < test.insertBefore; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkts.PushBack(pkt.Clone()) + for i := 0; i < test.insertAfter; i++ { + pkts.PushBack(tinyPacket.Clone()) + } - var results []stack.PacketBuffer - L: - for { - select { - case pi := <-c.linkEP.Ch: - results = append(results, pi) - default: - break L - } - } + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) - if got, want := len(results), ft.expectedFrags; got != want { - t.Errorf("len(result) got %d, want %d", got, want) - } - if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet len(result) got %d, want %d", got, want) + wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter + n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }) + if err != nil { + t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err) + } + if n != wantTotalPackets { + t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets) + } + if got := len(ep.WrittenPackets); got != wantTotalPackets { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) + } + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + + if wantTotalPackets == 0 { + return + } + + fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] + if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + }) } - compareFragments(t, results, source, ft.mtu) }) } } -// TestFragmentationErrors checks that errors are returned from write packet +// TestFragmentationErrors checks that errors are returned from WritePacket // correctly. func TestFragmentationErrors(t *testing.T) { - fragTests := []struct { + const ttl = 42 + + tests := []struct { description string mtu uint32 - hdrLength int - payloadViewsSizes []int - packetCollectorErrors []*tcpip.Error + transportHeaderLength int + payloadSize int + allowPackets int + outgoingErrors int + mockError *tcpip.Error + wantError *tcpip.Error }{ - {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, - {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, + { + description: "No frag", + mtu: 2000, + payloadSize: 1000, + transportHeaderLength: 0, + allowPackets: 0, + outgoingErrors: 1, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on first frag", + mtu: 500, + payloadSize: 1000, + transportHeaderLength: 0, + allowPackets: 0, + outgoingErrors: 3, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on second frag", + mtu: 500, + payloadSize: 1000, + transportHeaderLength: 0, + allowPackets: 1, + outgoingErrors: 2, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on first frag MTU smaller than header", + mtu: 500, + transportHeaderLength: 1000, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 4, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error when MTU is smaller than IPv4 minimum MTU", + mtu: header.IPv4MinimumMTU - 1, + transportHeaderLength: 0, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrInvalidEndpointState, + }, } - for _, ft := range fragTests { + for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) - c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ - Header: hdr, - Data: payload, - }) - for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { - if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { - t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) - } + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + r := buildRoute(t, ep) + err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }, pkt) + if err != ft.wantError { + t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) } - // We only need to check that last error because all the ones before are - // nil. - if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { - t.Errorf("err got %v, want %v", got, want) + if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { + t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) } - if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { - t.Errorf("after linkEP error len(result) got %d, want %d", got, want) + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) } }) } } func TestInvalidFragments(t *testing.T) { - // These packets have both IHL and TotalLength set to 0. - testCases := []struct { + const ( + nicID = 1 + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + addr1 = "\x0a\x00\x00\x01" + addr2 = "\x0a\x00\x00\x02" + tos = 0 + ident = 1 + ttl = 48 + protocol = 6 + ) + + payloadGen := func(payloadLen int) []byte { + payload := make([]byte, payloadLen) + for i := 0; i < len(payload); i++ { + payload[i] = 0x30 + } + return payload + } + + type fragmentData struct { + ipv4fields header.IPv4Fields + payload []byte + autoChecksum bool // if true, the Checksum field will be overwritten. + } + + tests := []struct { name string - packets [][]byte + fragments []fragmentData wantMalformedIPPackets uint64 wantMalformedFragments uint64 }{ { - "ihl_totallen_zero_valid_frag_offset", - [][]byte{ - {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "IHL and TotalLength zero, FragmentOffset non-zero", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: 0, + TOS: tos, + TotalLength: 0, + ID: ident, + Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments, + FragmentOffset: 59776, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(12), + autoChecksum: true, + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 0, + }, + { + name: "IHL and TotalLength zero, FragmentOffset zero", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: 0, + TOS: tos, + TotalLength: 0, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(12), + autoChecksum: true, + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 0, + }, + { + // Payload 17 octets and Fragment offset 65520 + // Leading to the fragment end to be past 65536. + name: "fragment ends past 65536", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 17, + ID: ident, + Flags: 0, + FragmentOffset: 65520, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(17), + autoChecksum: true, + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, + }, + { + // Payload 16 octets and fragment offset 65520 + // Leading to the fragment end to be exactly 65536. + name: "fragment ends exactly at 65536", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: 0, + FragmentOffset: 65520, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(16), + autoChecksum: true, + }, + }, + wantMalformedIPPackets: 0, + wantMalformedFragments: 0, + }, + { + name: "IHL less than IPv4 minimum size", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 12, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 28, + ID: ident, + Flags: 0, + FragmentOffset: 1944, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 12, + TOS: tos, + TotalLength: header.IPv4MinimumSize - 12, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, }, - 1, - 0, + wantMalformedIPPackets: 2, + wantMalformedFragments: 0, }, { - "ihl_totallen_zero_invalid_frag_offset", - [][]byte{ - {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "fragment with short TotalLength and extra payload", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize + 4, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 28, + ID: ident, + Flags: 0, + FragmentOffset: 28816, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize + 4, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 4, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, }, - 1, - 0, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, }, { - // Total Length of 37(20 bytes IP header + 17 bytes of - // payload) - // Frag Offset of 0x1ffe = 8190*8 = 65520 - // Leading to the fragment end to be past 65535. - "ihl_totallen_valid_invalid_frag_offset_1", - [][]byte{ - {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "multiple fragments with More Fragments flag set to false", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: 0, + FragmentOffset: 128, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(8), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: 0, + FragmentOffset: 8, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(8), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(8), + autoChecksum: true, + }, }, - 1, - 1, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, }, - // The following 3 tests were found by running a fuzzer and were - // triggering a panic in the IPv4 reassembler code. + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + }, + }) + e := channel.New(0, 1500, linkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + } + + for _, f := range test.fragments { + pktSize := header.IPv4MinimumSize + len(f.payload) + hdr := buffer.NewPrependable(pktSize) + + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&f.ipv4fields) + copy(ip[header.IPv4MinimumSize:], f.payload) + + if f.autoChecksum { + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + } + + vv := hdr.View().ToVectorisedView() + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { + t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) + } + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { + t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) + } + }) + } +} + +func TestFragmentReassemblyTimeout(t *testing.T) { + const ( + nicID = 1 + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + addr1 = "\x0a\x00\x00\x01" + addr2 = "\x0a\x00\x00\x02" + tos = 0 + ident = 1 + ttl = 48 + protocol = 99 + data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" + ) + + type fragmentData struct { + ipv4fields header.IPv4Fields + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + expectICMP bool + }{ { - "ihl_less_than_ipv4_minimum_size_1", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "first fragment only", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, }, - 2, - 0, + expectICMP: true, }, { - "ihl_less_than_ipv4_minimum_size_2", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "two first fragments", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, }, - 2, - 0, + expectICMP: true, }, { - "ihl_less_than_ipv4_minimum_size_3", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "second fragment only", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 8, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, }, - 2, - 0, + expectICMP: false, }, { - "fragment_with_short_total_len_extra_payload", - [][]byte{ - {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "two fragments with a gap", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:8], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 16, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, }, - 1, - 1, + expectICMP: true, }, { - "multiple_fragments_with_more_fragments_set_to_false", - [][]byte{ - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + name: "two fragments with a gap in reverse order", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 16, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:8], + }, }, - 1, - 1, + expectICMP: true, }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - const nicID tcpip.NICID = 42 + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, }, + Clock: clock, }) + e := channel.New(1, 1500, linkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }}) + + var firstFragmentSent buffer.View + for _, f := range test.fragments { + pktSize := header.IPv4MinimumSize + hdr := buffer.NewPrependable(pktSize) + + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&f.ipv4fields) - var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) - var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) - ep := channel.New(10, 1500, linkAddr) - s.CreateNIC(nicID, sniffer.New(ep)) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) - for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, }) + + if firstFragmentSent == nil && ip.FragmentOffset() == 0 { + firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(header.IPv4ProtocolNumber, pkt) } - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { - t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) + clock.Advance(ipv4.ReassembleTimeout) + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want { - t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) + if !ok { + t.Fatal("expected ICMP error message missing") + } + if firstFragmentSent == nil { + t.Fatalf("unexpected ICMP error message received: %#v", reply) } + + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4TimeExceeded), + checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), + checker.ICMPv4Checksum(), + checker.ICMPv4Payload([]byte(firstFragmentSent)), + ), + ) }) } } @@ -478,12 +1702,16 @@ func TestInvalidFragments(t *testing.T) { // TestReceiveFragments feeds fragments in through the incoming packet path to // test reassembly func TestReceiveFragments(t *testing.T) { - const addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 - const addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 - const nicID = 1 + const ( + nicID = 1 + + addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 + addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 + addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3 + ) // Build and return a UDP header containing payload. - udpGen := func(payloadLen int, multiplier uint8) buffer.View { + udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View { payload := buffer.NewView(payloadLen) for i := 0; i < len(payload); i++ { payload[i] = uint8(i) * multiplier @@ -499,20 +1727,32 @@ func TestReceiveFragments(t *testing.T) { Length: uint16(udpLength), }) copy(u.Payload(), payload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) sum = header.Checksum(payload, sum) u.SetChecksum(^u.CalculateChecksum(sum)) return hdr.View() } // UDP header plus a payload of 0..256 - ipv4Payload1 := udpGen(256, 1) - udpPayload1 := ipv4Payload1[header.UDPMinimumSize:] + ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2) + udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:] + ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2) + udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:] // UDP header plus a payload of 0..256 in increments of 2. - ipv4Payload2 := udpGen(128, 2) - udpPayload2 := ipv4Payload2[header.UDPMinimumSize:] + ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2) + udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:] + // UDP header plus a payload of 0..256 in increments of 3. + // Used to test cases where the fragment blocks are not a multiple of + // the fragment block size of 8 (RFC 791 section 3.1 page 14). + ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) + udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] + // Used to test the max reassembled payload length (65,535 octets). + ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2) + udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:] type fragmentData struct { + srcAddr tcpip.Address + dstAddr tcpip.Address id uint16 flags uint8 fragmentOffset uint16 @@ -528,22 +1768,40 @@ func TestReceiveFragments(t *testing.T) { name: "No fragmentation", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 0, - payload: ipv4Payload1, + payload: ipv4Payload1Addr1ToAddr2, }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "No fragmentation with size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 0, + payload: ipv4Payload3Addr1ToAddr2, + }, + }, + expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, }, { name: "More fragments without payload", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload1, + payload: ipv4Payload1Addr1ToAddr2, }, }, expectedPayloads: nil, @@ -552,10 +1810,12 @@ func TestReceiveFragments(t *testing.T) { name: "Non-zero fragment offset without payload", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 8, - payload: ipv4Payload1, + payload: ipv4Payload1Addr1ToAddr2, }, }, expectedPayloads: nil, @@ -564,34 +1824,108 @@ func TestReceiveFragments(t *testing.T) { name: "Two fragments", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1Addr1ToAddr2[64:], + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments out of order", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1Addr1ToAddr2[64:], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments with last fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload1[:64], + payload: ipv4Payload3Addr1ToAddr2[:64], }, { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 64, - payload: ipv4Payload1[64:], + payload: ipv4Payload3Addr1ToAddr2[64:], + }, + }, + expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, + }, + { + name: "Two fragments with first fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload3Addr1ToAddr2[:63], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 63, + payload: ipv4Payload3Addr1ToAddr2[63:], }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: nil, }, { name: "Second fragment has MoreFlags set", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload1[:64], + payload: ipv4Payload1Addr1ToAddr2[:64], }, { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 64, - payload: ipv4Payload1[64:], + payload: ipv4Payload1Addr1ToAddr2[64:], }, }, expectedPayloads: nil, @@ -600,16 +1934,20 @@ func TestReceiveFragments(t *testing.T) { name: "Two fragments with different IDs", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload1[:64], + payload: ipv4Payload1Addr1ToAddr2[:64], }, { + srcAddr: addr1, + dstAddr: addr2, id: 2, flags: 0, fragmentOffset: 64, - payload: ipv4Payload1[64:], + payload: ipv4Payload1Addr1ToAddr2[64:], }, }, expectedPayloads: nil, @@ -618,31 +1956,113 @@ func TestReceiveFragments(t *testing.T) { name: "Two interleaved fragmented packets", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload1[:64], + payload: ipv4Payload1Addr1ToAddr2[:64], }, { + srcAddr: addr1, + dstAddr: addr2, id: 2, flags: header.IPv4FlagMoreFragments, fragmentOffset: 0, - payload: ipv4Payload2[:64], + payload: ipv4Payload2Addr1ToAddr2[:64], }, { + srcAddr: addr1, + dstAddr: addr2, id: 1, flags: 0, fragmentOffset: 64, - payload: ipv4Payload1[64:], + payload: ipv4Payload1Addr1ToAddr2[64:], }, { + srcAddr: addr1, + dstAddr: addr2, id: 2, flags: 0, fragmentOffset: 64, - payload: ipv4Payload2[64:], + payload: ipv4Payload2Addr1ToAddr2[64:], + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, + }, + { + name: "Two interleaved fragmented packets from different sources but with same ID", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + { + srcAddr: addr3, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr3ToAddr2[:32], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1Addr1ToAddr2[64:], + }, + { + srcAddr: addr3, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 32, + payload: ipv4Payload1Addr3ToAddr2[32:], + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, + }, + { + name: "Fragment without followup", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1Addr1ToAddr2[:64], + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload4Addr1ToAddr2[:65512], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 65512, + payload: ipv4Payload4Addr1ToAddr2[65512:], }, }, - expectedPayloads: [][]byte{udpPayload1, udpPayload2}, + expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, }, } @@ -650,8 +2070,8 @@ func TestReceiveFragments(t *testing.T) { t.Run(test.name, func(t *testing.T) { // Setup a stack and endpoint. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) if err := s.CreateNIC(nicID, e); err != nil { @@ -691,16 +2111,17 @@ func TestReceiveFragments(t *testing.T) { FragmentOffset: frag.fragmentOffset, TTL: 64, Protocol: uint8(header.UDPProtocolNumber), - SrcAddr: addr1, - DstAddr: addr2, + SrcAddr: frag.srcAddr, + DstAddr: frag.dstAddr, }) + ip.SetChecksum(^ip.CalculateChecksum()) vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) - e.InjectInbound(header.IPv4ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { @@ -723,3 +2144,394 @@ func TestReceiveFragments(t *testing.T) { }) } } + +func TestWriteStats(t *testing.T) { + const nPackets = 3 + + tests := []struct { + name string + setup func(*testing.T, *stack.Stack) + allowPackets int + expectSent int + expectDropped int + expectWritten int + }{ + { + name: "Accept all", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: math.MaxInt32, + expectSent: nPackets, + expectDropped: 0, + expectWritten: nPackets, + }, { + name: "Accept all with error", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: nPackets - 1, + expectSent: nPackets - 1, + expectDropped: 0, + expectWritten: nPackets - 1, + }, { + name: "Drop all", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %s", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: 0, + expectDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule that matches only 1 + // of the 3 packets. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %s", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectDropped: 1, + expectWritten: nPackets, + }, + } + + // Parameterize the tests to run with both WritePacket and WritePackets. + writers := []struct { + name string + writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + }{ + { + name: "WritePacket", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + nWritten := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { + return nWritten, err + } + nWritten++ + } + return nWritten, nil + }, + }, { + name: "WritePackets", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) + }, + }, + } + + for _, writer := range writers { + t.Run(writer.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) + rt := buildRoute(t, ep) + + var pkts stack.PacketBufferList + for i := 0; i < nPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), + Data: buffer.NewView(0).ToVectorisedView(), + }) + pkt.TransportHeader().Push(header.UDPMinimumSize) + pkts.PushBack(pkt) + } + + test.setup(t, rt.Stack()) + + nWritten, _ := writer.writePackets(&rt, pkts) + + if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { + t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + } + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { + t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + } + if nWritten != test.expectWritten { + t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + } + }) + } + }) + } +} + +func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC(1, _) failed: %s", err) + } + const ( + src = "\x10\x00\x00\x01" + dst = "\x10\x00\x00\x02" + ) + if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err) + } + { + mask := tcpip.AddressMask(header.IPv4Broadcast) + subnet, err := tcpip.NewSubnet(dst, mask) + if err != nil { + t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err) + } + return rt +} + +// limitedMatcher is an iptables matcher that matches after a certain number of +// packets are checked against it. +type limitedMatcher struct { + limit int +} + +// Name implements Matcher.Name. +func (*limitedMatcher) Name() string { + return "limitedMatcher" +} + +// Match implements Matcher.Match. +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { + if lm.limit == 0 { + return true, false + } + lm.limit-- + return false, false +} + +func TestPacketQueing(t *testing.T) { + const nicID = 1 + + var ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") + + host1IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + host2IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 8, + }, + } + ) + + tests := []struct { + name string + rxPkt func(*channel.Endpoint) + checkResp func(*testing.T, *channel.Endpoint) + }{ + { + name: "ICMP Error", + rxPkt: func(e *channel.Endpoint) { + hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize, + TTL: ipv4.DefaultTTL, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != header.IPv4ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + }, + }, + + { + name: "Ping", + rxPkt: func(e *channel.Endpoint) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ipv4.DefaultTTL, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != header.IPv4ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply), + checker.ICMPv4Code(header.ICMPv4UnusedCode))) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(1, defaultMTU, host1NICLinkAddr) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + + // Receive a packet to trigger link resolution before a response is sent. + test.rxPkt(e) + + // Wait for a ARP request since link address resolution should be + // performed. + { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != arp.ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) + } + if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + rep := header.ARP(p.Pkt.NetworkHeader().View()) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address) + } + } + + // Send an ARP reply to complete link address resolution. + { + hdr := buffer.View(make([]byte, header.ARPSize)) + packet := header.ARP(hdr) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), host2NICLinkAddr) + copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address) + copy(packet.HardwareAddressTarget(), host1NICLinkAddr) + copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address) + e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.ToVectorisedView(), + })) + } + + // Expect the response now that the link address has resolved. + test.checkResp(t, e) + + // Since link resolution was already performed, it shouldn't be performed + // again. + test.rxPkt(e) + test.checkResp(t, e) + }) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 3f71fc520..0ac24a6fb 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -5,14 +5,18 @@ package(licenses = ["notice"]) go_library( name = "ipv6", srcs = [ + "dhcpv6configurationfromndpra_string.go", "icmp.go", "ipv6.go", + "ndp.go", ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", @@ -32,13 +36,16 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go new file mode 100644 index 000000000..09ba133b1 --- /dev/null +++ b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. + +package ipv6 + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[DHCPv6NoConfiguration-1] + _ = x[DHCPv6ManagedAddress-2] + _ = x[DHCPv6OtherConfigurations-3] +} + +const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations" + +var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66} + +func (i DHCPv6ConfigurationFromNDPRA) String() string { + i -= 1 + if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) { + return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]] +} diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..3c15e41a7 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -27,7 +27,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { return @@ -39,8 +39,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // is truncated, which would cause IsValid to return false. // // Drop packet if it doesn't have the basic IPv6 header or if the - // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + // original source address doesn't match an address we own. + src := hdr.SourceAddress() + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { return } @@ -67,20 +68,76 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { +// getLinkAddrOption searches NDP options for a given link address option using +// the provided getAddr function as a filter. Returns the link address if +// found; otherwise, returns the zero link address value. Also returns true if +// the options are valid as per the wire format, false otherwise. +func getLinkAddrOption(it header.NDPOptionIterator, getAddr func(header.NDPOption) tcpip.LinkAddress) (tcpip.LinkAddress, bool) { + var linkAddr tcpip.LinkAddress + for { + opt, done, err := it.Next() + if err != nil { + return "", false + } + if done { + break + } + if addr := getAddr(opt); len(addr) != 0 { + // No RFCs define what to do when an NDP message has multiple Link-Layer + // Address options. Since no interface can have multiple link-layer + // addresses, we consider such messages invalid. + if len(linkAddr) != 0 { + return "", false + } + linkAddr = addr + } + } + return linkAddr, true +} + +// getSourceLinkAddr searches NDP options for the source link address option. +// Returns the link address if found; otherwise, returns the zero link address +// value. Also returns true if the options are valid as per the wire format, +// false otherwise. +func getSourceLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { + return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress { + if src, ok := opt.(header.NDPSourceLinkLayerAddressOption); ok { + return src.EthernetAddress() + } + return "" + }) +} + +// getTargetLinkAddr searches NDP options for the target link address option. +// Returns the link address if found; otherwise, returns the zero link address +// value. Also returns true if the options are valid as per the wire format, +// false otherwise. +func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { + return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress { + if dst, ok := opt.(header.NDPTargetLinkLayerAddressOption); ok { + return dst.EthernetAddress() + } + return "" + }) +} + +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) if !ok { received.Invalid.Increment() return } h := header.ICMPv6(v) - iph := header.IPv6(netHeader) + iph := header.IPv6(pkt.NetworkHeader().View()) // Validate ICMPv6 checksum before processing the packet. // @@ -113,8 +170,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) + if err != nil { + networkMTU = 0 + } + e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -125,13 +185,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch header.ICMPv6(hdr).Code() { + case header.ICMPv6NetworkUnreachable: + e.handleControl(stack.ControlNetworkUnreachable, 0, pkt) case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize { received.Invalid.Increment() return } @@ -141,22 +203,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // NDP messages cannot be fragmented. Also note that in the common case NDP // datagrams are very small and ToView() will not incur allocations. ns := header.NDPNeighborSolicit(payload.ToView()) - it, err := ns.Options().Iter(true) - if err != nil { - // If we have a malformed NDP NS option, drop the packet. + targetAddr := ns.TargetAddress() + + // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast + // address. + if header.IsV6MulticastAddress(targetAddr) { received.Invalid.Increment() return } - targetAddr := ns.TargetAddress() - s := r.Stack() - if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { - // We will only get an error if the NIC is unrecognized, which should not - // happen. For now, drop this packet. - // - // TODO(b/141002840): Handle this better? - return - } else if isTentative { + if e.hasTentativeAddr(targetAddr) { // If the target address is tentative and the source of the packet is a // unicast (specified) address, then the source of the packet is // attempting to perform address resolution on the target. In this case, @@ -169,7 +225,20 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // stack know so it can handle such a scenario and do nothing further with // the NS. if r.RemoteAddress == header.IPv6Any { - s.DupTentativeAddrDetected(e.nicID, targetAddr) + // We would get an error if the address no longer exists or the address + // is no longer tentative (DAD resolved between the call to + // hasTentativeAddr and this point). Both of these are valid scenarios: + // 1) An address may be removed at any time. + // 2) As per RFC 4862 section 5.4, DAD is not a perfect: + // "Note that the method for detecting duplicates + // is not completely reliable, and it is possible that duplicate + // addresses will still exist" + // + // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate + // address is detected for an assigned address. + if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) + } } // Do not handle neighbor solicitations targeted to an address that is @@ -181,48 +250,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // so the packet is processed as defined in RFC 4861, as per RFC 4862 // section 5.4.3. - // Is the NS targetting us? - if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { + // Is the NS targeting us? + if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { return } - // If the NS message contains the Source Link-Layer Address option, update - // the link address cache with the value of the option. - // - // TODO(b/148429853): Properly process the NS message and do Neighbor - // Unreachability Detection. var sourceLinkAddr tcpip.LinkAddress - for { - opt, done, err := it.Next() + { + it, err := ns.Options().Iter(false /* check */) if err != nil { - // This should never happen as Iter(true) above did not return an error. - panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err)) - } - if done { - break + // Options are not valid as per the wire format, silently drop the + // packet. + received.Invalid.Increment() + return } - switch opt := opt.(type) { - case header.NDPSourceLinkLayerAddressOption: - // No RFCs define what to do when an NS message has multiple Source - // Link-Layer Address options. Since no interface can have multiple - // link-layer addresses, we consider such messages invalid. - if len(sourceLinkAddr) != 0 { - received.Invalid.Increment() - return - } - - sourceLinkAddr = opt.EthernetAddress() + sourceLinkAddr, ok = getSourceLinkAddr(it) + if !ok { + received.Invalid.Increment() + return } } - unspecifiedSource := r.RemoteAddress == header.IPv6Any - // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST // NOT be included when the source IP address is the unspecified address. // Otherwise, on link layers that have addresses this option MUST be // included in multicast solicitations and SHOULD be included in unicast // solicitations. + unspecifiedSource := r.RemoteAddress == header.IPv6Any if len(sourceLinkAddr) == 0 { if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource { received.Invalid.Increment() @@ -231,55 +286,88 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P } else if unspecifiedSource { received.Invalid.Increment() return + } else if e.nud != nil { + e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr) - } - - // ICMPv6 Neighbor Solicit messages are always sent to - // specially crafted IPv6 multicast addresses. As a result, the - // route we end up with here has as its LocalAddress such a - // multicast address. It would be nonsense to claim that our - // source address is a multicast address, so we manually set - // the source address to the target address requested in the - // solicit message. Since that requires mutating the route, we - // must first clone it. - r := r.Clone() - defer r.Release() - r.LocalAddress = targetAddr + e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) + } - // As per RFC 4861 section 7.2.4, if the the source of the solicitation is - // the unspecified address, the node MUST set the Solicited flag to zero and - // multicast the advertisement to the all-nodes address. - solicited := true + // As per RFC 4861 section 7.1.1: + // A node MUST silently discard any received Neighbor Solicitation + // messages that do not satisfy all of the following validity checks: + // ... + // - If the IP source address is the unspecified address, the IP + // destination address is a solicited-node multicast address. + if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) { + received.Invalid.Increment() + return + } + + // As per RFC 4861 section 7.2.4: + // + // If the source of the solicitation is the unspecified address, the node + // MUST [...] and multicast the advertisement to the all-nodes address. + // + remoteAddr := r.RemoteAddress if unspecifiedSource { - solicited = false - r.RemoteAddress = header.IPv6AllNodesMulticastAddress + remoteAddr = header.IPv6AllNodesMulticastAddress } - // If the NS has a source link-layer option, use the link address it - // specifies as the remote link address for the response instead of the - // source link address of the packet. + // Even if we were able to receive a packet from some remote, we may not + // have a route to it - the remote may be blocked via routing rules. We must + // always consult our routing table and find a route to the remote before + // sending any packet. + r, err := e.protocol.stack.FindRoute(e.nic.ID(), targetAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + // If we cannot find a route to the destination, silently drop the packet. + return + } + defer r.Release() + + // If the NS has a source link-layer option, resolve the route immediately + // to avoid querying the neighbor table when the neighbor entry was updated + // as probing the neighbor table for a link address will transition the + // entry's state from stale to delay. + // + // Note, if the source link address is unspecified and this is a unicast + // solicitation, we may need to perform neighbor discovery to send the + // neighbor advertisement response. This is expected as per RFC 4861 section + // 7.2.4: + // + // Because unicast Neighbor Solicitations are not required to include a + // Source Link-Layer Address, it is possible that a node sending a + // solicited Neighbor Advertisement does not have a corresponding link- + // layer address for its neighbor in its Neighbor Cache. In such + // situations, a node will first have to use Neighbor Discovery to + // determine the link-layer address of its neighbor (i.e., send out a + // multicast Neighbor Solicitation). // - // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link - // address cache for the right destination link address instead of manually - // patching the route with the remote link address if one is specified in a - // Source Link-Layer Address option. if len(sourceLinkAddr) != 0 { - r.RemoteLinkAddress = sourceLinkAddr + r.ResolveWith(sourceLinkAddr) } optsSerializer := header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress), + header.NDPTargetLinkLayerAddressOption(e.nic.LinkAddress()), } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + neighborAdvertSize := header.ICMPv6NeighborAdvertMinimumSize + optsSerializer.Length() + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborAdvertSize, + }) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber + packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(packet.NDPPayload()) - na.SetSolicitedFlag(solicited) + + // As per RFC 4861 section 7.2.4: + // + // If the source of the solicitation is the unspecified address, the node + // MUST set the Solicited flag to zero and [..]. Otherwise, the node MUST + // set the Solicited flag to one and [..]. + // + na.SetSolicitedFlag(!unspecifiedSource) na.SetOverrideFlag(true) na.SetTargetAddress(targetAddr) - opts := na.Options() - opts.Serialize(optsSerializer) + na.Options().Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) @@ -288,9 +376,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { sent.Dropped.Increment() return } @@ -298,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize { received.Invalid.Increment() return } @@ -308,28 +394,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // 5, NDP messages cannot be fragmented. Also note that in the common case // NDP datagrams are very small and ToView() will not incur allocations. na := header.NDPNeighborAdvert(payload.ToView()) - it, err := na.Options().Iter(true) - if err != nil { - // If we have a malformed NDP NA option, drop the packet. - received.Invalid.Increment() - return - } - targetAddr := na.TargetAddress() - stack := r.Stack() - - if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil { - // We will only get an error if the NIC is unrecognized, which should not - // happen. For now short-circuit this packet. - // - // TODO(b/141002840): Handle this better? - return - } else if isTentative { + if e.hasTentativeAddr(targetAddr) { // We just got an NA from a node that owns an address we are performing // DAD on, implying the address is not unique. In this case we let the // stack know so it can handle such a scenario and do nothing furthur with // the NDP NA. - stack.DupTentativeAddrDetected(e.nicID, targetAddr) + // + // We would get an error if the address no longer exists or the address + // is no longer tentative (DAD resolved between the call to + // hasTentativeAddr and this point). Both of these are valid scenarios: + // 1) An address may be removed at any time. + // 2) As per RFC 4862 section 5.4, DAD is not a perfect: + // "Note that the method for detecting duplicates + // is not completely reliable, and it is possible that duplicate + // addresses will still exist" + // + // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate + // address is detected for an assigned address. + if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) + } + return + } + + it, err := na.Options().Iter(false /* check */) + if err != nil { + // If we have a malformed NDP NA option, drop the packet. + received.Invalid.Increment() return } @@ -342,58 +434,59 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // TODO(b/143147598): Handle the scenario described above. Also inform the // netstack integration that a duplicate address was detected outside of // DAD. + targetLinkAddr, ok := getTargetLinkAddr(it) + if !ok { + received.Invalid.Increment() + return + } // If the NA message has the target link layer option, update the link // address cache with the link address for the target of the message. - // - // TODO(b/148429853): Properly process the NA message and do Neighbor - // Unreachability Detection. - var targetLinkAddr tcpip.LinkAddress - for { - opt, done, err := it.Next() - if err != nil { - // This should never happen as Iter(true) above did not return an error. - panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err)) - } - if done { - break - } - - switch opt := opt.(type) { - case header.NDPTargetLinkLayerAddressOption: - // No RFCs define what to do when an NA message has multiple Target - // Link-Layer Address options. Since no interface can have multiple - // link-layer addresses, we consider such messages invalid. - if len(targetLinkAddr) != 0 { - received.Invalid.Increment() - return - } - - targetLinkAddr = opt.EthernetAddress() + if e.nud == nil { + if len(targetLinkAddr) != 0 { + e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr) } + return } - if len(targetLinkAddr) != 0 { - e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr) - } + e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ + Solicited: na.SolicitedFlag(), + Override: na.OverrideFlag(), + IsRouter: na.RouterFlag(), + }) case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) if !ok { received.Invalid.Increment() return } - pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + + // As per RFC 4291 section 2.7, multicast addresses must not be used as + // source addresses in IPv6 packets. + localAddr := r.LocalAddress + if header.IsV6MulticastAddress(r.LocalAddress) { + localAddr = "" + } + + r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + // If we cannot find a route to the destination, silently drop the packet. + return + } + defer r.Release() + + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, + Data: pkt.Data, + }) + packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ - Header: hdr, - Data: pkt.Data, - }); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -415,27 +508,75 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterSolicit: received.RouterSolicit.Increment() - if !isNDPValid() { + + // + // Validate the RS as per RFC 4861 section 6.1.1. + // + + // Is the NDP payload of sufficient size to hold a Router Solictation? + if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { received.Invalid.Increment() return } - case header.ICMPv6RouterAdvert: - received.RouterAdvert.Increment() + stack := r.Stack() - // Is the NDP payload of sufficient size to hold a Router - // Advertisement? - if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { + // Is the networking stack operating as a router? + if !stack.Forwarding(ProtocolNumber) { + // ... No, silently drop the packet. + received.RouterOnlyPacketsDroppedByHost.Increment() + return + } + + // Note that in the common case NDP datagrams are very small and ToView() + // will not incur allocations. + rs := header.NDPRouterSolicit(payload.ToView()) + it, err := rs.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the packet. received.Invalid.Increment() return } - routerAddr := iph.SourceAddress() + sourceLinkAddr, ok := getSourceLinkAddr(it) + if !ok { + received.Invalid.Increment() + return + } + + // If the RS message has the source link layer option, update the link + // address cache with the link address for the source of the message. + if len(sourceLinkAddr) != 0 { + // As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST + // NOT be included when the source IP address is the unspecified address. + // Otherwise, it SHOULD be included on link layers that have addresses. + if r.RemoteAddress == header.IPv6Any { + received.Invalid.Increment() + return + } + + if e.nud != nil { + // A RS with a specified source IP address modifies the NUD state + // machine in the same way a reachability probe would. + e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + } + } + + case header.ICMPv6RouterAdvert: + received.RouterAdvert.Increment() // // Validate the RA as per RFC 4861 section 6.1.2. // + // Is the NDP payload of sufficient size to hold a Router Advertisement? + if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { + received.Invalid.Increment() + return + } + + routerAddr := iph.SourceAddress() + // Is the IP Source Address a link-local address? if !header.IsV6LinkLocalAddress(routerAddr) { // ...No, silently drop the packet. @@ -443,16 +584,18 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - // The remainder of payload must be only the router advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. + // Note that in the common case NDP datagrams are very small and ToView() + // will not incur allocations. ra := header.NDPRouterAdvert(payload.ToView()) - opts := ra.Options() + it, err := ra.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the packet. + received.Invalid.Increment() + return + } - // Are options valid as per the wire format? - if _, err := opts.Iter(true); err != nil { - // ...No, silently drop the packet. + sourceLinkAddr, ok := getSourceLinkAddr(it) + if !ok { received.Invalid.Increment() return } @@ -462,12 +605,33 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // as RFC 4861 section 6.1.2 is concerned. // - // Tell the NIC to handle the RA. - stack := r.Stack() - rxNICID := r.NICID() - stack.HandleNDPRA(rxNICID, routerAddr, ra) + // If the RA has the source link layer option, update the link address + // cache with the link address for the advertised router. + if len(sourceLinkAddr) != 0 && e.nud != nil { + e.nud.HandleProbe(routerAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + } + + e.mu.Lock() + e.mu.ndp.handleRA(routerAddr, ra) + e.mu.Unlock() case header.ICMPv6RedirectMsg: + // TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating + // this redirect message, as per RFC 4871 section 7.3.3: + // + // "A Neighbor Cache entry enters the STALE state when created as a + // result of receiving packets other than solicited Neighbor + // Advertisements (i.e., Router Solicitations, Router Advertisements, + // Redirects, and Neighbor Solicitations). These packets contain the + // link-layer address of either the sender or, in the case of Redirect, + // the redirection target. However, receipt of these link-layer + // addresses does not confirm reachability of the forward-direction path + // to that node. Placing a newly created Neighbor Cache entry for which + // the link-layer address is known in the STALE state provides assurance + // that path failures are detected quickly. In addition, should a cached + // link-layer address be modified due to receiving one of the above + // messages, the state SHOULD also be set to STALE to provide prompt + // verification that the path to the new link-layer address is working." received.RedirectMsg.Increment() if !isNDPValid() { received.Invalid.Increment() @@ -479,20 +643,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P } } -const ( - ndpSolicitedFlag = 1 << 6 - ndpOverrideFlag = 1 << 5 - - ndpOptSrcLinkAddr = 1 - ndpOptDstLinkAddr = 2 - - icmpV6FlagOffset = 4 - icmpV6OptOffset = 24 - icmpV6LengthOffset = 25 -) - -var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) - var _ stack.LinkAddressResolver = (*protocol)(nil) // LinkAddressProtocol implements stack.LinkAddressResolver. @@ -501,40 +651,46 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { - snaddr := header.SolicitedNodeAddr(addr) - - // TODO(b/148672031): Use stack.FindRoute instead of manually creating the - // route here. Note, we would need the nicID to do this properly so the right - // NIC (associated to linkEP) is used to send the NDP NS message. - r := &stack.Route{ - LocalAddress: localAddr, - RemoteAddress: snaddr, - RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr), +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + remoteAddr := targetAddr + if len(remoteLinkAddr) == 0 { + remoteAddr = header.SolicitedNodeAddr(targetAddr) + remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr) } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - copy(pkt[icmpV6OptOffset-len(addr):], addr) - pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr - pkt[icmpV6LengthOffset] = 1 - copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - length := uint16(hdr.UsedLength()) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: length, - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) - // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ - Header: hdr, + r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + r.ResolveWith(remoteLinkAddr) + + optsSerializer := header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), + } + neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize, }) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber + packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) + packet.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(packet.NDPPayload()) + ns.SetTargetAddress(targetAddr) + ns.Options().Serialize(optsSerializer) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + stat := p.stack.Stats().ICMP.V6PacketsSent + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, pkt); err != nil { + stat.Dropped.Increment() + return err + } + + stat.NeighborSolicit.Increment() + return nil } // ResolveStaticAddress implements stack.LinkAddressResolver. @@ -544,3 +700,192 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } return tcpip.LinkAddress([]byte(nil)), false } + +// ======= ICMP Error packet generation ========= + +// icmpReason is a marker interface for IPv6 specific ICMP errors. +type icmpReason interface { + isICMPReason() +} + +// icmpReasonParameterProblem is an error during processing of extension headers +// or the fixed header defined in RFC 4443 section 3.4. +type icmpReasonParameterProblem struct { + code header.ICMPv6Code + + // respondToMulticast indicates that we are sending a packet that falls under + // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2: + // + // (e.3) A packet destined to an IPv6 multicast address. (There are + // two exceptions to this rule: (1) the Packet Too Big Message + // (Section 3.2) to allow Path MTU discovery to work for IPv6 + // multicast, and (2) the Parameter Problem Message, Code 2 + // (Section 3.4) reporting an unrecognized IPv6 option (see + // Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + respondToMulticast bool + + // pointer is defined in the RFC 4443 setion 3.4 which reads: + // + // Pointer Identifies the octet offset within the invoking packet + // where the error was detected. + // + // The pointer will point beyond the end of the ICMPv6 + // packet if the field in error is beyond what can fit + // in the maximum size of an ICMPv6 error message. + pointer uint32 +} + +func (*icmpReasonParameterProblem) isICMPReason() {} + +// icmpReasonPortUnreachable is an error where the transport protocol has no +// listener and no alternative means to inform the sender. +type icmpReasonPortUnreachable struct{} + +func (*icmpReasonPortUnreachable) isICMPReason() {} + +// icmpReasonReassemblyTimeout is an error where insufficient fragments are +// received to complete reassembly of a packet within a configured time after +// the reception of the first-arriving fragment of that packet. +type icmpReasonReassemblyTimeout struct{} + +func (*icmpReasonReassemblyTimeout) isICMPReason() {} + +// returnError takes an error descriptor and generates the appropriate ICMP +// error packet for IPv6 and sends it. +func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + // Only send ICMP error if the address is not a multicast v6 + // address and the source is not the unspecified address. + // + // There are exceptions to this rule. + // See: point e.3) RFC 4443 section-2.4 + // + // (e) An ICMPv6 error message MUST NOT be originated as a result of + // receiving the following: + // + // (e.1) An ICMPv6 error message. + // + // (e.2) An ICMPv6 redirect message [IPv6-DISC]. + // + // (e.3) A packet destined to an IPv6 multicast address. (There are + // two exceptions to this rule: (1) the Packet Too Big Message + // (Section 3.2) to allow Path MTU discovery to work for IPv6 + // multicast, and (2) the Parameter Problem Message, Code 2 + // (Section 3.4) reporting an unrecognized IPv6 option (see + // Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + // + var allowResponseToMulticast bool + if reason, ok := reason.(*icmpReasonParameterProblem); ok { + allowResponseToMulticast = reason.respondToMulticast + } + + if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any { + return nil + } + + // Even if we were able to receive a packet from some remote, we may not have + // a route to it - the remote may be blocked via routing rules. We must always + // consult our routing table and find a route to the remote before sending any + // packet. + route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + // From this point on, the incoming route should no longer be used; route + // must be used to send the ICMP error. + r = nil + + stats := p.stack.Stats().ICMP + sent := stats.V6PacketsSent + if !p.stack.AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + + if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { + // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. + // Unfortunately at this time ICMP Packets do not have a transport + // header separated out. It is in the Data part so we need to + // separate it out now. We will just pretend it is a minimal length + // ICMP packet as we don't really care if any later bits of a + // larger ICMP packet are in the header view or in the Data view. + transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize) + if !ok { + return nil + } + typ := header.ICMPv6(transport).Type() + if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg { + return nil + } + } + + // As per RFC 4443 section 2.4 + // + // (c) Every ICMPv6 error message (type < 128) MUST include + // as much of the IPv6 offending (invoking) packet (the + // packet that caused the error) as possible without making + // the error message packet exceed the minimum IPv6 MTU + // [IPv6]. + mtu := int(route.MTU()) + if mtu > header.IPv6MinimumMTU { + mtu = header.IPv6MinimumMTU + } + headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize + available := int(mtu) - headerLen + if available < header.IPv6MinimumSize { + return nil + } + payloadLen := network.Size() + transport.Size() + pkt.Data.Size() + if payloadLen > available { + payloadLen = available + } + payload := network.ToVectorisedView() + payload.AppendView(transport) + payload.Append(pkt.Data) + payload.CapLength(payloadLen) + + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, + }) + newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber + + icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) + var counter *tcpip.StatCounter + switch reason := reason.(type) { + case *icmpReasonParameterProblem: + icmpHdr.SetType(header.ICMPv6ParamProblem) + icmpHdr.SetCode(reason.code) + icmpHdr.SetTypeSpecific(reason.pointer) + counter = sent.ParamProblem + case *icmpReasonPortUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6PortUnreachable) + counter = sent.DstUnreachable + case *icmpReasonReassemblyTimeout: + icmpHdr.SetType(header.ICMPv6TimeExceeded) + icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) + counter = sent.TimeExceeded + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data)) + if err := route.WritePacket( + nil, /* gso */ + stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, + newPkt, + ); err != nil { + sent.Dropped.Increment() + return err + } + counter.Increment() + return nil +} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..aa8b5f2e5 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -16,37 +16,57 @@ package ipv6 import ( "context" + "net" "reflect" "strings" "testing" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) const ( + nicID = 1 + linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") + + defaultChannelSize = 1 + defaultMTU = 65536 + + // Extra time to use when waiting for an async event to occur. + defaultAsyncPositiveEventTimeout = 30 * time.Second ) var ( lladdr0 = header.LinkLocalAddr(linkAddr0) lladdr1 = header.LinkLocalAddr(linkAddr1) + lladdr2 = header.LinkLocalAddr(linkAddr2) ) type stubLinkEndpoint struct { stack.LinkEndpoint } +func (*stubLinkEndpoint) MTU() uint32 { + return defaultMTU +} + func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return 0 + // Indicate that resolution for link layer addresses is required to send + // packets over this link. This is needed so the NIC knows to allocate a + // neighbor table. + return stack.CapabilityResolutionRequired } func (*stubLinkEndpoint) MaxHeaderLength() uint16 { @@ -57,7 +77,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { return nil } @@ -67,7 +87,8 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { + return stack.TransportPacketHandled } type stubLinkAddressCache struct { @@ -81,16 +102,225 @@ func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtoco func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { } +type stubNUDHandler struct { + probeCount int + confirmationCount int +} + +var _ stack.NUDHandler = (*stubNUDHandler)(nil) + +func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) { + s.probeCount++ +} + +func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { + s.confirmationCount++ +} + +func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) { +} + +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.LinkEndpoint + + nicID tcpip.NICID +} + +func (*testInterface) ID() tcpip.NICID { + return nicID +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + r := stack.Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) +} + func TestICMPCounts(t *testing.T) { + tests := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + UseNeighborCache: test.useNeighborCache, + }) + { + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + netProto := s.NetworkProtocolInstance(ProtocolNumber) + if netProto == nil { + t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) + } + ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + } + defer r.Release() + + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + + types := []struct { + typ header.ICMPv6Type + size int + extraData []byte + }{ + { + typ: header.ICMPv6DstUnreachable, + size: header.ICMPv6DstUnreachableMinimumSize, + }, + { + typ: header.ICMPv6PacketTooBig, + size: header.ICMPv6PacketTooBigMinimumSize, + }, + { + typ: header.ICMPv6TimeExceeded, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6ParamProblem, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6EchoRequest, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6EchoReply, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + }, + { + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, + { + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + }, + { + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + }, + } + + handleIPv6Payload := func(icmp header.ICMPv6) { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + ep.HandlePacket(&r, pkt) + } + + for _, typ := range types { + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + handleIPv6Payload(icmp) + } + + // Construct an empty ICMP packet so that + // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. + handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) + + icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { + if got, want := s.Value(), uint64(1); got != want { + t.Errorf("got %s = %d, want = %d", name, got, want) + } + }) + if t.Failed() { + t.Logf("stats:\n%+v", s.Stats()) + } + }) + } +} + +func TestICMPCountsWithNeighborCache(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + UseNeighborCache: true, }) { - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) } } @@ -102,7 +332,7 @@ func TestICMPCounts(t *testing.T) { s.SetRouteTable( []tcpip.Route{{ Destination: subnet, - NIC: 1, + NIC: nicID, }}, ) } @@ -111,14 +341,16 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) + ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{}) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) } defer r.Release() @@ -179,36 +411,33 @@ func TestICMPCounts(t *testing.T) { }, } - handleIPv6Payload := func(hdr buffer.Prependable) { - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + handleIPv6Payload := func(icmp header.ICMPv6) { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), + PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) + ep.HandlePacket(&r, pkt) } for _, typ := range types { - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) - - handleIPv6Payload(hdr) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + handleIPv6Payload(icmp) } // Construct an empty ICMP packet so that // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize)) + handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) icmpv6Stats := s.Stats().ICMP.V6PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { @@ -252,35 +481,34 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab func newTestContext(t *testing.T) *testContext { c := &testContext{ s0: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }), s1: stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }), } - const defaultMTU = 65536 - c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { wrappedEP0 = sniffer.New(wrappedEP0) } - if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { + if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } - if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress lladdr0: %v", err) } - c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) - if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { + if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { + if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil { t.Fatalf("AddAddress lladdr1: %v", err) } @@ -291,7 +519,7 @@ func newTestContext(t *testing.T) *testContext { c.s0.SetRouteTable( []tcpip.Route{{ Destination: subnet0, - NIC: 1, + NIC: nicID, }}, ) subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) @@ -301,7 +529,7 @@ func newTestContext(t *testing.T) *testContext { c.s1.SetRouteTable( []tcpip.Route{{ Destination: subnet1, - NIC: 1, + NIC: nicID, }}, ) @@ -325,12 +553,10 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. pi, _ := args.src.ReadContext(context.Background()) { - views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} - size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() - vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{ - Data: vv, + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()), }) + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt) } if pi.Proto != ProtocolNumber { @@ -342,7 +568,9 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) } - ipv6 := header.IPv6(pi.Pkt.Header.View()) + // Pull the full payload since network header. Needed for header.IPv6 to + // extract its payload. + ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader())) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) @@ -362,9 +590,9 @@ func TestLinkResolution(t *testing.T) { c := newTestContext(t) defer c.cleanup() - r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) } defer r.Release() @@ -379,14 +607,14 @@ func TestLinkResolution(t *testing.T) { var wq waiter.Queue ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) + t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err) } for { - _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}) + _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}) if resCh != nil { if err != tcpip.ErrNoLinkAddress { - t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err) + t.Fatalf("ep.Write(_) = (_, <non-nil>, %s), want = (_, <non-nil>, tcpip.ErrNoLinkAddress)", err) } for _, args := range []routeArgs{ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, @@ -402,7 +630,7 @@ func TestLinkResolution(t *testing.T) { continue } if err != nil { - t.Fatalf("ep.Write(_) = _, _, %s", err) + t.Fatalf("ep.Write(_) = (_, _, %s)", err) } break } @@ -427,6 +655,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { size int extraData []byte statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + routerOnly bool }{ { name: "DstUnreachable", @@ -483,6 +712,8 @@ func TestICMPChecksumValidationSimple(t *testing.T) { statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterSolicit }, + // Hosts MUST silently discard any received Router Solicitation messages. + routerOnly: true, }, { name: "RouterAdvert", @@ -519,86 +750,133 @@ func TestICMPChecksumValidationSimple(t *testing.T) { }, } - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } + tests := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, typ := range types { + for _, isRouter := range []bool{false, true} { + name := typ.name + if isRouter { + name += " (Router)" + } + t.Run(name, func(t *testing.T) { + e := channel.New(0, 1280, linkAddr0) + + // Indicate that resolution for link layer addresses is required to + // send packets over this link. This is needed so the NIC knows to + // allocate a neighbor table. + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: test.useNeighborCache, + }) + if isRouter { + // Enabling forwarding makes the stack act as a router. + s.SetForwarding(ProtocolNumber, true) + } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + handleIPv6Payload := func(checksum bool) { + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + if checksum { + icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) + } + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), + }) + e.InjectInbound(ProtocolNumber, pkt) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + routerOnly := stats.RouterOnlyPacketsDroppedByHost + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := routerOnly.Value(); got != 0 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Router only count should not have increased. + if got := routerOnly.Value(); got != 0 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + if !isRouter && typ.routerOnly && test.useNeighborCache { + // Router only count should have increased. + if got := routerOnly.Value(); got != 1 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got) + } + } + }) } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - handleIPv6Payload := func(checksum bool) { - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView())) - } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(typ.size + extraDataLen), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - } - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) } }) } @@ -699,13 +977,13 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { t.Run(typ.name, func(t *testing.T) { e := channel.New(10, 1280, linkAddr0) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) } { @@ -716,7 +994,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { s.SetRouteTable( []tcpip.Route{{ Destination: subnet, - NIC: 1, + NIC: nicID, }}, ) } @@ -724,12 +1002,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { icmpSize := size + payloadSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(typ) - payloadFn(pkt.Payload()) + icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize)) + icmpHdr.SetType(typ) + payloadFn(icmpHdr.Payload()) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{})) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -740,9 +1018,10 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -754,7 +1033,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // Without setting checksum, the incoming packet should @@ -765,13 +1044,13 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { } // Rx count of type typ.typ should not have increased. if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // When checksum is set, it should be received. handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // Invalid count should not have increased again. if got := invalid.Value(); got != 1 { @@ -876,14 +1155,14 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { t.Run(typ.name, func(t *testing.T) { e := channel.New(10, 1280, linkAddr0) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -893,21 +1172,21 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { s.SetRouteTable( []tcpip.Route{{ Destination: subnet, - NIC: 1, + NIC: nicID, }}, ) } handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) + icmpHdr := header.ICMPv6(hdr.Prepend(size)) + icmpHdr.SetType(typ) payload := buffer.NewView(payloadSize) payloadFn(payload) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView())) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -918,9 +1197,10 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -932,7 +1212,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // Without setting checksum, the incoming packet should @@ -943,13 +1223,13 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { } // Rx count of type typ.typ should not have increased. if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // When checksum is set, it should be received. handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // Invalid count should not have increased again. if got := invalid.Value(); got != 1 { @@ -958,3 +1238,573 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { }) } } + +func TestLinkAddressRequest(t *testing.T) { + const nicID = 1 + + snaddr := header.SolicitedNodeAddr(lladdr0) + mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) + + tests := []struct { + name string + nicAddr tcpip.Address + localAddr tcpip.Address + remoteLinkAddr tcpip.LinkAddress + + expectedErr *tcpip.Error + expectedRemoteAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress + }{ + { + name: "Unicast", + nicAddr: lladdr1, + localAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedRemoteAddr: lladdr0, + expectedRemoteLinkAddr: linkAddr1, + }, + { + name: "Multicast", + nicAddr: lladdr1, + localAddr: lladdr1, + remoteLinkAddr: "", + expectedRemoteAddr: snaddr, + expectedRemoteLinkAddr: mcaddr, + }, + { + name: "Unicast with unspecified source", + nicAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedRemoteAddr: lladdr0, + expectedRemoteLinkAddr: linkAddr1, + }, + { + name: "Multicast with unspecified source", + nicAddr: lladdr1, + remoteLinkAddr: "", + expectedRemoteAddr: snaddr, + expectedRemoteLinkAddr: mcaddr, + }, + { + name: "Unicast with unassigned address", + localAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with unassigned address", + localAddr: lladdr1, + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Unicast with no local address available", + remoteLinkAddr: linkAddr1, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with no local address available", + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, + }, + } + + for _, test := range tests { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + p := s.NetworkProtocolInstance(ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") + } + + linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + } + } + + // We pass a test network interface to LinkAddressRequest with the same NIC + // ID and link endpoint used by the NIC we created earlier so that we can + // mock a link address request and observe the packets sent to the link + // endpoint even though the stack uses the real NIC. + if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } + + if test.expectedErr != nil { + return + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) + } + if pkt.Route.RemoteAddress != test.expectedRemoteAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) + } + if pkt.Route.LocalAddress != lladdr1 { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1) + } + checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), + checker.SrcAddr(lladdr1), + checker.DstAddr(test.expectedRemoteAddr), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(lladdr0), + checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), + )) + } +} + +func TestPacketQueing(t *testing.T) { + const nicID = 1 + + var ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") + + host1IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::1").To16()), + PrefixLen: 64, + }, + } + host2IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::2").To16()), + PrefixLen: 64, + }, + } + ) + + tests := []struct { + name string + rxPkt func(*channel.Endpoint) + checkResp func(*testing.T, *channel.Endpoint) + }{ + { + name: "ICMP Error", + rxPkt: func(e *channel.Endpoint) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + }, + }, + + { + name: "Ping", + rxPkt: func(e *channel.Endpoint) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply), + checker.ICMPv6Code(header.ICMPv6UnusedCode))) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + + e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + + // Receive a packet to trigger link resolution before a response is sent. + test.rxPkt(e) + + // Wait for a neighbor solicitation since link address resolution should + // be performed. + { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != ProtocolNumber { + t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) + } + snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address), + checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}), + )) + } + + // Send a neighbor advertisement to complete link address resolution. + { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) + pkt := header.ICMPv6(hdr.Prepend(naSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr), + }) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: header.NDPHopLimit, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + // Expect the response now that the link address has resolved. + test.checkResp(t, e) + + // Since link resolution was already performed, it shouldn't be performed + // again. + test.rxPkt(e) + test.checkResp(t, e) + }) + } +} + +func TestCallsToNeighborCache(t *testing.T) { + tests := []struct { + name string + createPacket func() header.ICMPv6 + multicast bool + source tcpip.Address + destination tcpip.Address + wantProbeCount int + wantConfirmationCount int + }{ + { + name: "Unicast Neighbor Solicitation without source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + return icmp + }, + source: lladdr1, + destination: lladdr0, + // "The source link-layer address option SHOULD be included in unicast + // solicitations." - RFC 4861 section 4.3 + // + // A Neighbor Advertisement needs to be sent in response, but the + // Neighbor Cache shouldn't be updated since we have no useful + // information about the sender. + wantProbeCount: 0, + }, + { + name: "Unicast Neighbor Solicitation with source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + ns.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: lladdr0, + wantProbeCount: 1, + }, + { + name: "Multicast Neighbor Solicitation without source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + return icmp + }, + source: lladdr1, + destination: header.SolicitedNodeAddr(lladdr0), + // "The source link-layer address option MUST be included in multicast + // solicitations." - RFC 4861 section 4.3 + wantProbeCount: 0, + }, + { + name: "Multicast Neighbor Solicitation with source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + ns.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: header.SolicitedNodeAddr(lladdr0), + wantProbeCount: 1, + }, + { + name: "Unicast Neighbor Advertisement without target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + return icmp + }, + source: lladdr1, + destination: lladdr0, + // "When responding to unicast solicitations, the target link-layer + // address option can be omitted since the sender of the solicitation has + // the correct link-layer address; otherwise, it would not be able to + // send the unicast solicitation in the first place." + // - RFC 4861 section 4.4 + wantConfirmationCount: 1, + }, + { + name: "Unicast Neighbor Advertisement with target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: lladdr0, + wantConfirmationCount: 1, + }, + { + name: "Multicast Neighbor Advertisement without target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(false) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + return icmp + }, + source: lladdr1, + destination: header.IPv6AllNodesMulticastAddress, + // "Target link-layer address MUST be included for multicast solicitations + // in order to avoid infinite Neighbor Solicitation "recursion" when the + // peer node does not have a cache entry to return a Neighbor + // Advertisements message." - RFC 4861 section 4.4 + wantConfirmationCount: 0, + }, + { + name: "Multicast Neighbor Advertisement with target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(false) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: header.IPv6AllNodesMulticastAddress, + wantConfirmationCount: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + UseNeighborCache: true, + }) + { + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + netProto := s.NetworkProtocolInstance(ProtocolNumber) + if netProto == nil { + t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) + } + nudHandler := &stubNUDHandler{} + ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + } + defer r.Release() + + // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch. + r.LocalAddress = test.destination + + icmp := test.createPacket() + icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{})) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: r.RemoteAddress, + DstAddr: r.LocalAddress, + }) + ep.HandlePacket(&r, pkt) + + // Confirm the endpoint calls the correct NUDHandler method. + if nudHandler.probeCount != test.wantProbeCount { + t.Errorf("got nudHandler.probeCount = %d, want = %d", nudHandler.probeCount, test.wantProbeCount) + } + if nudHandler.confirmationCount != test.wantConfirmationCount { + t.Errorf("got nudHandler.confirmationCount = %d, want = %d", nudHandler.confirmationCount, test.wantConfirmationCount) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index daf1fcbc6..1e38f3a9d 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,98 +12,373 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ipv6 contains the implementation of the ipv6 network protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv6.NewProtocol() as one of the network -// protocols when calling stack.New(). Then endpoints can be created by passing -// ipv6.ProtocolNumber as the network protocol number when calling -// Stack.NewEndpoint(). +// Package ipv6 contains the implementation of the ipv6 network protocol. package ipv6 import ( + "encoding/binary" "fmt" + "hash/fnv" + "sort" "sync/atomic" + "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( + // As per RFC 8200 section 4.5: + // If insufficient fragments are received to complete reassembly of a packet + // within 60 seconds of the reception of the first-arriving fragment of that + // packet, reassembly of that packet must be abandoned. + // + // Linux also uses 60 seconds for reassembly timeout: + // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456 + ReassembleTimeout = 60 * time.Second + // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber - // maxTotalSize is maximum size that can be encoded in the 16-bit + // maxPayloadSize is the maximum size that can be encoded in the 16-bit // PayloadLength field of the ipv6 header. maxPayloadSize = 0xffff // DefaultTTL is the default hop limit for IPv6 Packets egressed by // Netstack. DefaultTTL = 64 + + // buckets for fragment identifiers + buckets = 2048 ) +var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) +var _ stack.NDPEndpoint = (*endpoint)(nil) +var _ NDPEndpoint = (*endpoint)(nil) + type endpoint struct { - nicID tcpip.NICID - id stack.NetworkEndpointID - prefixLen int - linkEP stack.LinkEndpoint + nic stack.NetworkInterface linkAddrCache stack.LinkAddressCache + nud stack.NUDHandler dispatcher stack.TransportDispatcher - fragmentation *fragmentation.Fragmentation protocol *protocol + stack *stack.Stack + + // enabled is set to 1 when the endpoint is enabled and 0 when it is + // disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + + mu struct { + sync.RWMutex + + addressableEndpointState stack.AddressableEndpointState + ndp ndpState + } } -// DefaultTTL is the default hop limit for this endpoint. -func (e *endpoint) DefaultTTL() uint8 { - return e.protocol.DefaultTTL() +// NICNameFromID is a function that returns a stable name for the specified NIC, +// even if different NIC IDs are used to refer to the same NIC in different +// program runs. It is used when generating opaque interface identifiers (IIDs). +// If the NIC was created with a name, it is passed to NICNameFromID. +// +// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are +// generated for the same prefix on differnt NICs. +type NICNameFromID func(tcpip.NICID, string) string + +// OpaqueInterfaceIdentifierOptions holds the options related to the generation +// of opaque interface indentifiers (IIDs) as defined by RFC 7217. +type OpaqueInterfaceIdentifierOptions struct { + // NICNameFromID is a function that returns a stable name for a specified NIC, + // even if the NIC ID changes over time. + // + // Must be specified to generate the opaque IID. + NICNameFromID NICNameFromID + + // SecretKey is a pseudo-random number used as the secret key when generating + // opaque IIDs as defined by RFC 7217. The key SHOULD be at least + // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness + // requirements for security as outlined by RFC 4086. SecretKey MUST NOT + // change between program runs, unless explicitly changed. + // + // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey + // MUST NOT be modified after Stack is created. + // + // May be nil, but a nil value is highly discouraged to maintain + // some level of randomness between nodes. + SecretKey []byte } -// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus -// the network layer max header length. -func (e *endpoint) MTU() uint32 { - return calculateMTU(e.linkEP.MTU()) +// InvalidateDefaultRouter implements stack.NDPEndpoint. +func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { + e.mu.Lock() + defer e.mu.Unlock() + e.mu.ndp.invalidateDefaultRouter(rtr) } -// NICID returns the ID of the NIC this endpoint belongs to. -func (e *endpoint) NICID() tcpip.NICID { - return e.nicID +// SetNDPConfigurations implements NDPEndpoint. +func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) { + c.validate() + e.mu.Lock() + defer e.mu.Unlock() + e.mu.ndp.configs = c } -// ID returns the ipv6 endpoint ID. -func (e *endpoint) ID() *stack.NetworkEndpointID { - return &e.id +// hasTentativeAddr returns true if addr is tentative on e. +func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool { + e.mu.RLock() + addressEndpoint := e.getAddressRLocked(addr) + e.mu.RUnlock() + return addressEndpoint != nil && addressEndpoint.GetKind() == stack.PermanentTentative } -// PrefixLen returns the ipv6 endpoint subnet prefix length in bits. -func (e *endpoint) PrefixLen() int { - return e.prefixLen +// dupTentativeAddrDetected attempts to inform e that a tentative addr is a +// duplicate on a link. +// +// dupTentativeAddrDetected removes the tentative address if it exists. If the +// address was generated via SLAAC, an attempt is made to generate a new +// address. +func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + addressEndpoint := e.getAddressRLocked(addr) + if addressEndpoint == nil { + return tcpip.ErrBadAddress + } + + if addressEndpoint.GetKind() != stack.PermanentTentative { + return tcpip.ErrInvalidEndpointState + } + + // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an + // attempt will be made to generate a new address for it. + if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + return err + } + + prefix := addressEndpoint.AddressWithPrefix().Subnet() + + switch t := addressEndpoint.ConfigType(); t { + case stack.AddressConfigStatic: + case stack.AddressConfigSlaac: + e.mu.ndp.regenerateSLAACAddr(prefix) + case stack.AddressConfigSlaacTemp: + // Do not reset the generation attempts counter for the prefix as the + // temporary address is being regenerated in response to a DAD conflict. + e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) + default: + panic(fmt.Sprintf("unrecognized address config type = %d", t)) + } + + return nil } -// Capabilities implements stack.NetworkEndpoint.Capabilities. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() +// transitionForwarding transitions the endpoint's forwarding status to +// forwarding. +// +// Must only be called when the forwarding status changes. +func (e *endpoint) transitionForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if !e.Enabled() { + return + } + + if forwarding { + // When transitioning into an IPv6 router, host-only state (NDP discovered + // routers, discovered on-link prefixes, and auto-generated addresses) is + // cleaned up/invalidated and NDP router solicitations are stopped. + e.mu.ndp.stopSolicitingRouters() + e.mu.ndp.cleanupState(true /* hostOnly */) + } else { + // When transitioning into an IPv6 host, NDP router solicitations are + // started. + e.mu.ndp.startSolicitingRouters() + } } -// MaxHeaderLength returns the maximum length needed by ipv6 headers (and -// underlying protocols). -func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize +// Enable implements stack.NetworkEndpoint. +func (e *endpoint) Enable() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // If the NIC is not enabled, the endpoint can't do anything meaningful so + // don't enable the endpoint. + if !e.nic.Enabled() { + return tcpip.ErrNotPermitted + } + + // If the endpoint is already enabled, there is nothing for it to do. + if !e.setEnabled(true) { + return nil + } + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + // + // Also auto-generate an IPv6 link-local address based on the endpoint's + // link address if it is configured to do so. Note, each interface is + // required to have IPv6 link-local unicast address, as per RFC 4291 + // section 2.1. + + // Join the All-Nodes multicast group before starting DAD as responses to DAD + // (NDP NS) messages may be sent to the All-Nodes multicast group if the + // source address of the NDP NS is the unspecified address, as per RFC 4861 + // section 7.2.4. + if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + + // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent + // state. + // + // Addresses may have aleady completed DAD but in the time since the endpoint + // was last enabled, other devices may have acquired the same addresses. + var err *tcpip.Error + e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + addr := addressEndpoint.AddressWithPrefix().Address + if !header.IsV6UnicastAddress(addr) { + return true + } + + switch addressEndpoint.GetKind() { + case stack.Permanent: + addressEndpoint.SetKind(stack.PermanentTentative) + fallthrough + case stack.PermanentTentative: + err = e.mu.ndp.startDuplicateAddressDetection(addr, addressEndpoint) + return err == nil + default: + return true + } + }) + if err != nil { + return err + } + + // Do not auto-generate an IPv6 link-local address for loopback devices. + if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() { + // The valid and preferred lifetime is infinite for the auto-generated + // link-local address. + e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) + } + + // If we are operating as a router, then do not solicit routers since we + // won't process the RAs anyway. + // + // Routers do not process Router Advertisements (RA) the same way a host + // does. That is, routers do not learn from RAs (e.g. on-link prefixes + // and default routers). Therefore, soliciting RAs from other routers on + // a link is unnecessary for routers. + if !e.protocol.Forwarding() { + e.mu.ndp.startSolicitingRouters() + } + + return nil } -// GSOMaxSize returns the maximum GSO packet size. -func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { - return gso.GSOMaxSize() +// Enabled implements stack.NetworkEndpoint. +func (e *endpoint) Enabled() bool { + return e.nic.Enabled() && e.isEnabled() +} + +// isEnabled returns true if the endpoint is enabled, regardless of the +// enabled status of the NIC. +func (e *endpoint) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) == 1 +} + +// setEnabled sets the enabled status for the endpoint. +// +// Returns true if the enabled status was updated. +func (e *endpoint) setEnabled(v bool) bool { + if v { + return atomic.SwapUint32(&e.enabled, 1) == 0 + } + return atomic.SwapUint32(&e.enabled, 0) == 1 +} + +// Disable implements stack.NetworkEndpoint. +func (e *endpoint) Disable() { + e.mu.Lock() + defer e.mu.Unlock() + e.disableLocked() +} + +func (e *endpoint) disableLocked() { + if !e.setEnabled(false) { + return + } + + e.mu.ndp.stopSolicitingRouters() + e.mu.ndp.cleanupState(false /* hostOnly */) + e.stopDADForPermanentAddressesLocked() + + // The endpoint may have already left the multicast group. + if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) + } +} + +// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) stopDADForPermanentAddressesLocked() { + // Stop DAD for all the tentative unicast addresses. + e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.GetKind() != stack.PermanentTentative { + return true + } + + addr := addressEndpoint.AddressWithPrefix().Address + if header.IsV6UnicastAddress(addr) { + e.mu.ndp.stopDuplicateAddressDetection(addr) + } + + return true + }) +} + +// DefaultTTL is the default hop limit for this endpoint. +func (e *endpoint) DefaultTTL() uint8 { + return e.protocol.DefaultTTL() +} + +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +func (e *endpoint) MTU() uint32 { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize) + if err != nil { + return 0 } - return 0 + return networkMTU +} + +// MaxHeaderLength returns the maximum length needed by ipv6 headers (and +// underlying protocols). +func (e *endpoint) MaxHeaderLength() uint16 { + return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { - length := uint16(hdr.UsedLength() + payloadSize) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(params.Protocol), @@ -112,25 +387,97 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - return ip + pkt.NetworkProtocolNumber = ProtocolNumber +} + +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { + payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU +} + +// handleFragments fragments pkt and calls the handler function on each +// fragment. It returns the number of fragments handled and the number of +// fragments left to be processed. The IP header must already be present in the +// original packet. The transport header protocol number is required to avoid +// parsing the IPv6 extension headers. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + networkHeader := header.IPv6(pkt.NetworkHeader().View()) + + // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are + // supported for outbound packets, their length should not affect the fragment + // maximum payload length because they should only be transmitted once. + fragmentPayloadLen := (networkMTU - header.IPv6FragmentHeaderSize) &^ 7 + if fragmentPayloadLen < header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit { + // We need at least 8 bytes of space left for the fragmentable part because + // the fragment payload must obviously be non-zero and must be a multiple + // of 8 as per RFC 8200 section 4.5: + // Each complete fragment, except possibly the last ("rightmost") one, is + // an integer multiple of 8 octets long. + return 0, 1, tcpip.ErrMessageTooLong + } + + if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) { + // As per RFC 8200 Section 4.5, the Transport Header is expected to be small + // enough to fit in the first fragment. + return 0, 1, tcpip.ErrMessageTooLong + } + + pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt)) + id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1) + + var n int + for { + fragPkt, more := buildNextFragment(&pf, networkHeader, transProto, id) + if err := handler(fragPkt); err != nil { + return n, pf.RemainingFragmentCount() + 1, err + } + n++ + if !more { + return n, pf.RemainingFragmentCount(), nil + } + } } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { + e.addIPHeader(r, pkt, params) + return e.writePacket(r, gso, pkt, params.Protocol) +} + +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error { + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesOutputDropped.Increment() + return nil + } + + // If the packet is manipulated as per NAT Output rules, handle packet + // based on destination address and do not send the packet to link + // layer. + // + // TODO(gvisor.dev/issue/170): We should do this for every + // packet, rather than only NATted packets, but removing this check + // short circuits broadcasts before they are sent out to other hosts. + if pkt.NatDone { + netHeader := header.IPv6(pkt.NetworkHeader().View()) + if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + ep.HandlePacket(&route, pkt) + return nil + } + } if r.Loop&stack.PacketLoop != 0 { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{ + // The inbound path expects an unparsed packet. + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) loopedR.Release() } @@ -138,11 +485,35 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each + // fragment one by one using WritePacket() (current strategy) or if we + // want to create a PacketBufferList from the fragments and feed it to + // WritePackets(). It'll be faster but cost more memory. + return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) + }) + r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + return err + } + + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt) + return nil } -// WritePackets implements stack.LinkEndpoint.WritePackets. +// WritePackets implements stack.NetworkEndpoint.WritePackets. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") @@ -151,45 +522,163 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) - pb.NetworkHeader = buffer.View(ip) + e.addIPHeader(r, pb, params) + + networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + if packetMustBeFragmented(pb, networkMTU, gso) { + // Keep track of the packet that is about to be fragmented so it can be + // removed once the fragmentation is done. + originalPkt := pb + if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // Modify the packet list in place with the new fragments. + pkts.InsertAfter(pb, fragPkt) + pb = fragPkt + return nil + }); err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + // Remove the packet that was just fragmented and process the rest. + pkts.Remove(originalPkt) + } + } + + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + if len(dropped) == 0 && len(natPkts) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + } + return n, err + } + r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + + // Slow path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { + continue + } + if _, ok := natPkts[pkt]; ok { + netHeader := header.IPv6(pkt.NetworkHeader().View()) + if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + ep.HandlePacket(&route, pkt) + n++ + continue + } + } + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) + // Dropped packets aren't errors, so include them in + // the return value. + return n + len(dropped), err + } + n++ } - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + // Dropped packets aren't errors, so include them in the return value. + return n + len(dropped), nil } -// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet -// supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { - // TODO(b/146666412): Support IPv6 header-included packets. - return tcpip.ErrNotSupported +// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { + // The packet already has an IP header, but there are a few required checks. + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return tcpip.ErrMalformedHeader + } + ip := header.IPv6(h) + + // Always set the payload length. + pktSize := pkt.Data.Size() + ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) + + // Set the source address when zero. + if ip.SourceAddress() == header.IPv6Any { + ip.SetSourceAddress(r.LocalAddress) + } + + // Set the destination. If the packet already included a destination, it will + // be part of the route anyways. + ip.SetDestinationAddress(r.RemoteAddress) + + // Populate the packet buffer's network header and don't allow an invalid + // packet to be sent. + // + // Note that parsing only makes sure that the packet is well formed as per the + // wire format. We also want to check if the header's fields are valid before + // sending the packet. + proto, _, _, _, ok := parse.IPv6(pkt) + if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) { + return tcpip.ErrMalformedHeader + } + + return e.writePacket(r, nil /* gso */, pkt, proto) } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + if !e.isEnabled() { return } - h := header.IPv6(headerView) - if !h.IsValid(pkt.Data.Size()) { + + h := header.IPv6(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } - pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] - pkt.Data.TrimFront(header.IPv6MinimumSize) - pkt.Data.CapLength(int(h.PayloadLength())) + // As per RFC 4291 section 2.7: + // Multicast addresses must not be used as source addresses in IPv6 + // packets or appear in any Routing header. + if header.IsV6MulticastAddress(r.RemoteAddress) { + r.Stats().IP.InvalidSourceAddressesReceived.Increment() + return + } - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data) + // vv consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView() + vv.AppendView(pkt.TransportHeader().View()) + vv.Append(pkt.Data) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false - for firstHeader := true; ; firstHeader = false { + // iptables filtering. All packets that reach here are intended for + // this machine and need not be forwarded. + ipt := e.protocol.stack.IPTables() + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesInputDropped.Increment() + return + } + + for { + // Keep track of the start of the previous header so we can report the + // special case of a Hop by Hop at a location other than at the start. + previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -203,11 +692,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { case header.IPv6HopByHopOptionsExtHdr: // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 - // (unrecognized next header) error in response to an extension header's - // Next Header field with the Hop By Hop extension header identifier. - if !firstHeader { + if previousHeaderStart != 0 { + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: previousHeaderStart, + }, pkt) return } @@ -229,13 +718,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: return - case header.IPv6OptionUnknownActionDiscardSendICMP: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. - return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. + if header.IsV6MulticastAddress(r.LocalAddress) { + return + } + fallthrough + case header.IPv6OptionUnknownActionDiscardSendICMP: + // This case satisfies a requirement of RFC 8200 section 4.2 + // which states that an unknown option starting with bits [10] should: + // + // discard the packet and, regardless of whether or not the + // packet's Destination Address was a multicast address, send an + // ICMP Parameter Problem, Code 2, message to the packet's + // Source Address, pointing to the unrecognized Option Type. + // + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownOption, + pointer: it.ParseOffset() + optsIt.OptionOffset(), + respondToMulticast: true, + }, pkt) return default: panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt)) @@ -246,25 +747,27 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // As per RFC 8200 section 4.4, if a node encounters a routing header with // an unrecognized routing type value, with a non-zero Segments Left // value, the node must discard the packet and send an ICMP Parameter - // Problem, Code 0. If the Segments Left is 0, the node must ignore the - // Routing extension header and process the next header in the packet. + // Problem, Code 0 to the packet's Source Address, pointing to the + // unrecognized Routing Type. + // + // If the Segments Left is 0, the node must ignore the Routing extension + // header and process the next header in the packet. // // Note, the stack does not yet handle any type of routing extension // header, so we just make sure Segments Left is zero before processing // the next extension header. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for - // unrecognized routing types with a non-zero Segments Left value. if extHdr.SegmentsLeft() != 0 { + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: it.ParseOffset(), + }, pkt) return } case header.IPv6FragmentExtHdr: hasFragmentHeader = true - fragmentOffset := extHdr.FragmentOffset() - more := extHdr.More() - if !more && fragmentOffset == 0 { + if extHdr.IsAtomic() { // This fragment extension header indicates that this packet is an // atomic fragment. An atomic fragment is a fragment that contains // all the data required to reassemble a full packet. As per RFC 6946, @@ -274,12 +777,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { continue } + fragmentFieldOffset := it.ParseOffset() + // Don't consume the iterator if we have the first fragment because we // will use it to validate that the first fragment holds the upper layer // header. - rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */) + rawPayload := it.AsRawHeader(extHdr.FragmentOffset() != 0 /* consume */) - if fragmentOffset == 0 { + if extHdr.FragmentOffset() == 0 { // Check that the iterator ends with a raw payload as the first fragment // should include all headers up to and including any upper layer // headers, as per RFC 8200 section 4.5; only upper layer data @@ -290,7 +795,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { it, done, err := it.Next() if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() return } if done { @@ -331,32 +836,89 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { return } + // As per RFC 2460 Section 4.5: + // + // If the length of a fragment, as derived from the fragment packet's + // Payload Length field, is not a multiple of 8 octets and the M flag + // of that fragment is 1, then that fragment must be discarded and an + // ICMP Parameter Problem, Code 0, message should be sent to the source + // of the fragment, pointing to the Payload Length field of the + // fragment packet. + if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: header.IPv6PayloadLenOffset, + }, pkt) + return + } + // The packet is a fragment, let's try to reassemble it. - start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit - last := start + uint16(fragmentPayloadLen) - 1 + start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit - // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and pkt.Data.size() causes a - // wrap around resulting in last being less than the offset. - if last < start { + // As per RFC 2460 Section 4.5: + // + // If the length and offset of a fragment are such that the Payload + // Length of the packet reassembled from that fragment would exceed + // 65,535 octets, then that fragment must be discarded and an ICMP + // Parameter Problem, Code 0, message should be sent to the source of + // the fragment, pointing to the Fragment Offset field of the fragment + // packet. + if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: fragmentFieldOffset, + }, pkt) return } - var ready bool - pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf) + // Set up a callback in case we need to send a Time Exceeded Message as + // per RFC 2460 Section 4.5. + var releaseCB func(bool) + if start == 0 { + pkt := pkt.Clone() + r := r.Clone() + releaseCB = func(timedOut bool) { + if timedOut { + _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt) + } + r.Release() + } + } + + // Note that pkt doesn't have its transport header set after reassembly, + // and won't until DeliverNetworkPacket sets it. + data, proto, ready, err := e.protocol.fragmentation.Process( + // IPv6 ignores the Protocol field since the ID only needs to be unique + // across source-destination pairs, as per RFC 8200 section 4.5. + fragmentation.FragmentID{ + Source: h.SourceAddress(), + Destination: h.DestinationAddress(), + ID: extHdr.ID(), + }, + start, + start+uint16(fragmentPayloadLen)-1, + extHdr.More(), + uint8(rawPayload.Identifier), + rawPayload.Buf, + releaseCB, + ) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() return } + pkt.Data = data if ready { // We create a new iterator with the reassembled packet because we could // have more extension headers in the reassembled payload, as per RFC - // 8200 section 4.5. - it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data) + // 8200 section 4.5. We also use the NextHeader value from the first + // fragment. + it = header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(proto), pkt.Data) } case header.IPv6DestinationOptionsExtHdr: @@ -378,13 +940,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: return - case header.IPv6OptionUnknownActionDiscardSendICMP: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. - return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. + if header.IsV6MulticastAddress(r.LocalAddress) { + return + } + fallthrough + case header.IPv6OptionUnknownActionDiscardSendICMP: + // This case satisfies a requirement of RFC 8200 section 4.2 + // which states that an unknown option starting with bits [10] should: + // + // discard the packet and, regardless of whether or not the + // packet's Destination Address was a multicast address, send an + // ICMP Parameter Problem, Code 2, message to the packet's + // Source Address, pointing to the unrecognized Option Type. + // + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownOption, + pointer: it.ParseOffset() + optsIt.OptionOffset(), + respondToMulticast: true, + }, pkt) return default: panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt)) @@ -394,23 +968,64 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { case header.IPv6RawPayloadHeader: // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. + + // For unfragmented packets, extHdr still contains the transport header. + // Get rid of it. + // + // For reassembled fragments, pkt.TransportHeader is unset, so this is a + // no-op and pkt.Data begins with the transport header. + extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf + r.Stats().IP.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, pkt, hasFragmentHeader) + pkt.TransportProtocolNumber = p + e.handleICMP(r, pkt, hasFragmentHeader) } else { r.Stats().IP.PacketsDelivered.Increment() - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error - // in response to unrecognized next header values. - e.dispatcher.DeliverTransportPacket(r, p, pkt) + switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + case stack.TransportPacketHandled: + case stack.TransportPacketDestinationPortUnreachable: + // As per RFC 4443 section 3.1: + // A destination node SHOULD originate a Destination Unreachable + // message with Code 4 in response to a packet for which the + // transport protocol (e.g., UDP) has no listener, if that transport + // protocol has no alternative means to inform the sender. + _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + case stack.TransportPacketProtocolUnreachable: + // As per RFC 8200 section 4. (page 7): + // Extension headers are numbered from IANA IP Protocol Numbers + // [IANA-PN], the same values used for IPv4 and IPv6. When + // processing a sequence of Next Header values in a packet, the + // first one that is not an extension header [IANA-EH] indicates + // that the next item in the packet is the corresponding upper-layer + // header. + // With more related information on page 8: + // If, as a result of processing a header, the destination node is + // required to proceed to the next header but the Next Header value + // in the current header is unrecognized by the node, it should + // discard the packet and send an ICMP Parameter Problem message to + // the source of the packet, with an ICMP Code value of 1 + // ("unrecognized Next Header type encountered") and the ICMP + // Pointer field containing the offset of the unrecognized value + // within the original packet. + // + // Which when taken together indicate that an unknown protocol should + // be treated as an unrecognized next header value. + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: it.ParseOffset(), + }, pkt) + default: + panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) + } } default: - // If we receive a packet for an extension header we do not yet handle, - // drop the packet for now. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error - // in response to unrecognized next header values. + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: it.ParseOffset(), + }, pkt) r.Stats().UnknownProtocolRcvdPackets.Increment() return } @@ -418,18 +1033,343 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { } // Close cleans up resources associated with the endpoint. -func (*endpoint) Close() {} +func (e *endpoint) Close() { + e.mu.Lock() + e.disableLocked() + e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */) + e.stopDADForPermanentAddressesLocked() + e.mu.addressableEndpointState.Cleanup() + e.mu.Unlock() + + e.protocol.forgetEndpoint(e) +} // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } +// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { + // TODO(b/169350103): add checks here after making sure we no longer receive + // an empty address. + e.mu.Lock() + defer e.mu.Unlock() + return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) +} + +// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but +// with locking requirements. +// +// addAndAcquirePermanentAddressLocked also joins the passed address's +// solicited-node multicast group and start duplicate address detection. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { + addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + if err != nil { + return nil, err + } + + if !header.IsV6UnicastAddress(addr.Address) { + return addressEndpoint, nil + } + + snmc := header.SolicitedNodeAddr(addr.Address) + if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil { + return nil, err + } + + addressEndpoint.SetKind(stack.PermanentTentative) + + if e.Enabled() { + if err := e.mu.ndp.startDuplicateAddressDetection(addr.Address, addressEndpoint); err != nil { + return nil, err + } + } + + return addressEndpoint, nil +} + +// RemovePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + addressEndpoint := e.getAddressRLocked(addr) + if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { + return tcpip.ErrBadLocalAddress + } + + return e.removePermanentEndpointLocked(addressEndpoint, true) +} + +// removePermanentEndpointLocked is like removePermanentAddressLocked except +// it works with a stack.AddressEndpoint. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error { + addr := addressEndpoint.AddressWithPrefix() + unicast := header.IsV6UnicastAddress(addr.Address) + if unicast { + e.mu.ndp.stopDuplicateAddressDetection(addr.Address) + + // If we are removing an address generated via SLAAC, cleanup + // its SLAAC resources and notify the integrator. + switch addressEndpoint.ConfigType() { + case stack.AddressConfigSlaac: + e.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + case stack.AddressConfigSlaacTemp: + e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + } + } + + if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil { + return err + } + + if !unicast { + return nil + } + + snmc := header.SolicitedNodeAddr(addr.Address) + if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + + return nil +} + +// hasPermanentAddressLocked returns true if the endpoint has a permanent +// address equal to the passed address. +// +// Precondition: e.mu must be read or write locked. +func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool { + addressEndpoint := e.getAddressRLocked(addr) + if addressEndpoint == nil { + return false + } + return addressEndpoint.GetKind().IsPermanent() +} + +// getAddressRLocked returns the endpoint for the passed address. +// +// Precondition: e.mu must be read or write locked. +func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint { + return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr) +} + +// MainAddress implements stack.AddressableEndpoint. +func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.MainAddress() +} + +// AcquireAssignedAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + e.mu.Lock() + defer e.mu.Unlock() + return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB) +} + +// acquireAddressOrCreateTempLocked is like AcquireAssignedAddress but with +// locking requirements. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) acquireAddressOrCreateTempLocked(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + return e.mu.addressableEndpointState.AcquireAssignedAddress(localAddr, allowTemp, tempPEB) +} + +// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) +} + +// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress +// but with locking requirements. +// +// Precondition: e.mu must be read locked. +func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { + // addrCandidate is a candidate for Source Address Selection, as per + // RFC 6724 section 5. + type addrCandidate struct { + addressEndpoint stack.AddressEndpoint + scope header.IPv6AddressScope + } + + if len(remoteAddr) == 0 { + return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) + } + + // Create a candidate set of available addresses we can potentially use as a + // source address. + var cs []addrCandidate + e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { + // If r is not valid for outgoing connections, it is not a valid endpoint. + if !addressEndpoint.IsAssigned(allowExpired) { + return + } + + addr := addressEndpoint.AddressWithPrefix().Address + scope, err := header.ScopeForIPv6Address(addr) + if err != nil { + // Should never happen as we got r from the primary IPv6 endpoint list and + // ScopeForIPv6Address only returns an error if addr is not an IPv6 + // address. + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) + } + + cs = append(cs, addrCandidate{ + addressEndpoint: addressEndpoint, + scope: scope, + }) + }) + + remoteScope, err := header.ScopeForIPv6Address(remoteAddr) + if err != nil { + // primaryIPv6Endpoint should never be called with an invalid IPv6 address. + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) + } + + // Sort the addresses as per RFC 6724 section 5 rules 1-3. + // + // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. + sort.Slice(cs, func(i, j int) bool { + sa := cs[i] + sb := cs[j] + + // Prefer same address as per RFC 6724 section 5 rule 1. + if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + return true + } + if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + return false + } + + // Prefer appropriate scope as per RFC 6724 section 5 rule 2. + if sa.scope < sb.scope { + return sa.scope >= remoteScope + } else if sb.scope < sa.scope { + return sb.scope < remoteScope + } + + // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. + if saDep, sbDep := sa.addressEndpoint.Deprecated(), sb.addressEndpoint.Deprecated(); saDep != sbDep { + // If sa is not deprecated, it is preferred over sb. + return sbDep + } + + // Prefer temporary addresses as per RFC 6724 section 5 rule 7. + if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp { + return saTemp + } + + // sa and sb are equal, return the endpoint that is closest to the front of + // the primary endpoint list. + return i < j + }) + + // Return the most preferred address that can have its reference count + // incremented. + for _, c := range cs { + if c.addressEndpoint.IncRef() { + return c.addressEndpoint + } + } + + return nil +} + +// PrimaryAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PrimaryAddresses() +} + +// PermanentAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PermanentAddresses() +} + +// JoinGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + if !header.IsV6MulticastAddress(addr) { + return false, tcpip.ErrBadAddress + } + + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.JoinGroup(addr) +} + +// LeaveGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.LeaveGroup(addr) +} + +// IsInGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) IsInGroup(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.IsInGroup(addr) +} + +var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) +var _ stack.NetworkProtocol = (*protocol)(nil) + type protocol struct { + stack *stack.Stack + + mu struct { + sync.RWMutex + + eps map[*endpoint]struct{} + } + + ids []uint32 + hashIV uint32 + // defaultTTL is the current default TTL for the protocol. Only the - // uint8 portion of it is meaningful and it must be accessed - // atomically. + // uint8 portion of it is meaningful. + // + // Must be accessed using atomic operations. defaultTTL uint32 + + // forwarding is set to 1 when the protocol has forwarding enabled and 0 + // when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + + fragmentation *fragmentation.Fragmentation + + // ndpDisp is the NDP event dispatcher that is used to send the netstack + // integrator NDP related events. + ndpDisp NDPDispatcher + + // ndpConfigs is the default NDP configurations used by an IPv6 endpoint. + ndpConfigs NDPConfigurations + + // opaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + opaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // tempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + tempIIDSeed []byte + + // autoGenIPv6LinkLocal determines whether or not the stack attempts to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. + autoGenIPv6LinkLocal bool } // Number returns the ipv6 protocol number. @@ -454,24 +1394,42 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { - return &endpoint{ - nicID: nicID, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, - linkEP: linkEP, +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &endpoint{ + nic: nic, linkAddrCache: linkAddrCache, + nud: nud, dispatcher: dispatcher, - fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), protocol: p, - }, nil + } + e.mu.addressableEndpointState.Init(e) + e.mu.ndp = ndpState{ + ep: e, + configs: p.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), + } + e.mu.ndp.initializeTempAddrState() + + p.mu.Lock() + defer p.mu.Unlock() + p.mu.eps[e] = struct{}{} + return e +} + +func (p *protocol) forgetEndpoint(e *endpoint) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.mu.eps, e) } // SetOption implements NetworkProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case tcpip.DefaultTTLOption: - p.SetDefaultTTL(uint8(v)) + case *tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(*v)) return nil default: return tcpip.ErrUnknownProtocolOption @@ -479,7 +1437,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } // Option implements NetworkProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(p.DefaultTTL()) @@ -505,17 +1463,193 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// calculateMTU calculates the network-layer payload MTU based on the link-layer -// payload mtu. -func calculateMTU(mtu uint32) uint32 { - mtu -= header.IPv6MinimumSize - if mtu <= maxPayloadSize { - return mtu +// Parse implements stack.NetworkProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) + if !ok { + return 0, false, false } - return maxPayloadSize + + return proto, !fragMore && fragOffset == 0, true } -// NewProtocol returns an IPv6 network protocol. -func NewProtocol() stack.NetworkProtocol { - return &protocol{defaultTTL: DefaultTTL} +// Forwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) Forwarding() bool { + return uint8(atomic.LoadUint32(&p.forwarding)) == 1 +} + +// setForwarding sets the forwarding status for the protocol. +// +// Returns true if the forwarding status was updated. +func (p *protocol) setForwarding(v bool) bool { + if v { + return atomic.SwapUint32(&p.forwarding, 1) == 0 + } + return atomic.SwapUint32(&p.forwarding, 0) == 1 +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) SetForwarding(v bool) { + p.mu.Lock() + defer p.mu.Unlock() + + if !p.setForwarding(v) { + return + } + + for ep := range p.mu.eps { + ep.transitionForwarding(v) + } +} + +// calculateNetworkMTU calculates the network-layer payload MTU based on the +// link-layer payload MTU and the length of every IPv6 header. +// Note that this is different than the Payload Length field of the IPv6 header, +// which includes the length of the extension headers. +func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) { + if linkMTU < header.IPv6MinimumMTU { + return 0, tcpip.ErrInvalidEndpointState + } + + // As per RFC 7112 section 5, we should discard packets if their IPv6 header + // is bigger than 1280 bytes (ie, the minimum link MTU) since we do not + // support PMTU discovery: + // Hosts that do not discover the Path MTU MUST limit the IPv6 Header Chain + // length to 1280 bytes. Limiting the IPv6 Header Chain length to 1280 + // bytes ensures that the header chain length does not exceed the IPv6 + // minimum MTU. + if networkHeadersLen > header.IPv6MinimumMTU { + return 0, tcpip.ErrMalformedHeader + } + + networkMTU := linkMTU - uint32(networkHeadersLen) + if networkMTU > maxPayloadSize { + networkMTU = maxPayloadSize + } + return networkMTU, nil +} + +// Options holds options to configure a new protocol. +type Options struct { + // NDPConfigs is the default NDP configurations used by interfaces. + NDPConfigs NDPConfigurations + + // AutoGenIPv6LinkLocal determines whether or not the stack attempts to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. + // + // Note, setting this to true does not mean that a link-local address is + // assigned right away, or at all. If Duplicate Address Detection is enabled, + // an address is only assigned if it successfully resolves. If it fails, no + // further attempts are made to auto-generate an IPv6 link-local adddress. + // + // The generated link-local address follows RFC 4291 Appendix A guidelines. + AutoGenIPv6LinkLocal bool + + // NDPDisp is the NDP event dispatcher that an integrator can provide to + // receive NDP related events. + NDPDisp NDPDispatcher + + // OpaqueIIDOpts hold the options for generating opaque interface + // identifiers (IIDs) as outlined by RFC 7217. + OpaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // TempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + // + // Temporary SLAAC adresses are short-lived addresses which are unpredictable + // and random from the perspective of other nodes on the network. It is + // recommended that the seed be a random byte buffer of at least + // header.IIDSize bytes to make sure that temporary SLAAC addresses are + // sufficiently random. It should follow minimum randomness requirements for + // security as outlined by RFC 4086. + // + // Note: using a nil value, the same seed across netstack program runs, or a + // seed that is too small would reduce randomness and increase predictability, + // defeating the purpose of temporary SLAAC addresses. + TempIIDSeed []byte +} + +// NewProtocolWithOptions returns an IPv6 network protocol. +func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { + opts.NDPConfigs.validate() + + ids := hash.RandN32(buckets) + hashIV := hash.RandN32(1)[0] + + return func(s *stack.Stack) stack.NetworkProtocol { + p := &protocol{ + stack: s, + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), + ids: ids, + hashIV: hashIV, + + ndpDisp: opts.NDPDisp, + ndpConfigs: opts.NDPConfigs, + opaqueIIDOpts: opts.OpaqueIIDOpts, + tempIIDSeed: opts.TempIIDSeed, + autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + } + p.mu.eps = make(map[*endpoint]struct{}) + p.SetDefaultTTL(DefaultTTL) + return p + } +} + +// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options. +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return NewProtocolWithOptions(Options{})(s) +} + +func calculateFragmentReserve(pkt *stack.PacketBuffer) int { + return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize +} + +// hashRoute calculates a hash value for the given route. It uses the source & +// destination address and 32-bit number to generate the hash. +func hashRoute(r *stack.Route, hashIV uint32) uint32 { + // The FNV-1a was chosen because it is a fast hashing algorithm, and + // cryptographic properties are not needed here. + h := fnv.New32a() + if _, err := h.Write([]byte(r.LocalAddress)); err != nil { + panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err)) + } + + if _, err := h.Write([]byte(r.RemoteAddress)); err != nil { + panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err)) + } + + s := make([]byte, 4) + binary.LittleEndian.PutUint32(s, hashIV) + if _, err := h.Write(s); err != nil { + panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected ever to return an error", err)) + } + + return h.Sum32() +} + +func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders header.IPv6, transportProto tcpip.TransportProtocolNumber, id uint32) (*stack.PacketBuffer, bool) { + fragPkt, offset, copied, more := pf.BuildNextFragment() + fragPkt.NetworkProtocolNumber = ProtocolNumber + + originalIPHeadersLength := len(originalIPHeaders) + fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize + fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) + + // Copy the IPv6 header and any extension headers already populated. + if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { + panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength)) + } + fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader) + fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) + + fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:]) + fragmentHeader.Encode(&header.IPv6FragmentFields{ + M: more, + FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), + Identification: id, + NextHeader: uint8(transportProto), + }) + + return fragPkt, more } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 841a0cb7a..c593c0004 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -15,15 +15,22 @@ package ipv6 import ( + "encoding/hex" + "fmt" + "math" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -43,6 +50,8 @@ const ( fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) + + extraHeaderReserve = 50 ) // testReceiveICMP tests receiving an ICMP packet from src to dst. want is the @@ -51,8 +60,8 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst t.Helper() // Receive ICMP packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize)) pkt.SetType(header.ICMPv6NeighborAdvert) pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() @@ -65,9 +74,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stats := s.Stats().ICMP.V6PacketsReceived @@ -123,9 +132,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stat := s.Stats().UDP.PacketsReceived @@ -134,25 +143,103 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst } } +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { + // sourcePacket does not have its IP Header populated. Let's copy the one + // from the first fragment. + source := header.IPv6(packets[0].NetworkHeader().View()) + sourceIPHeadersLen := len(source) + vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) + source = append(source, vv.ToView()...) + + var reassembledPayload buffer.VectorisedView + for i, fragment := range packets { + // Confirm that the packet is valid. + allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views()) + fragmentIPHeaders := header.IPv6(allBytes.ToView()) + if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) { + return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders)) + } + + fragmentIPHeadersLength := fragment.NetworkHeader().View().Size() + if fragmentIPHeadersLength != sourceIPHeadersLen { + return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen) + } + + if got := len(fragmentIPHeaders); got > int(mtu) { + return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu) + } + + sourceIPHeader := source[:header.IPv6MinimumSize] + fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize] + + if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize { + return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize) + } + + // We expect the IPv6 Header to be similar across each fragment, besides the + // payload length. + sourceIPHeader.SetPayloadLength(0) + fragmentIPHeader.SetPayloadLength(0) + if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" { + return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) + } + + if got := fragment.AvailableHeaderBytes(); got != extraHeaderReserve { + return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) + } + if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber { + return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber) + } + + if len(packets) > 1 { + // If the source packet was big enough that it needed fragmentation, let's + // inspect the fragment header. Because no other extension headers are + // supported, it will always be the last extension header. + fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength]) + + if got := fragmentHeader.More(); got != wantFragments[i].more { + return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more) + } + if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset { + return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset) + } + if got := fragmentHeader.NextHeader(); got != uint8(proto) { + return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto)) + } + } + + // Store the reassembled payload as we parse each fragment. The payload + // includes the Transport header and everything after. + reassembledPayload.AppendView(fragment.TransportHeader().View()) + reassembledPayload.Append(fragment.Data) + } + + if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" { + return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + + return nil +} + // TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and // UDP packets destined to the IPv6 link-local all-nodes multicast address. func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { tests := []struct { name string - protocolFactory stack.TransportProtocol + protocolFactory stack.TransportProtocolFactory rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, + {"ICMP", icmp.NewProtocol6, testReceiveICMP}, + {"UDP", udp.NewProtocol, testReceiveUDP}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) - e := channel.New(10, 1280, linkAddr1) + e := channel.New(10, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } @@ -168,15 +255,13 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { // packets destined to the IPv6 solicited-node address of an assigned IPv6 // address. func TestReceiveOnSolicitedNodeAddr(t *testing.T) { - const nicID = 1 - tests := []struct { name string - protocolFactory stack.TransportProtocol + protocolFactory stack.TransportProtocolFactory rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) }{ - {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, - {"UDP", udp.NewProtocol(), testReceiveUDP}, + {"ICMP", icmp.NewProtocol6, testReceiveICMP}, + {"UDP", udp.NewProtocol, testReceiveUDP}, } snmc := header.SolicitedNodeAddr(addr2) @@ -184,16 +269,16 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) - e := channel.New(1, 1280, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: header.IPv6EmptySubnet, NIC: nicID, }, @@ -271,7 +356,7 @@ func TestAddIpv6Address(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -293,17 +378,22 @@ func TestAddIpv6Address(t *testing.T) { } func TestReceiveIPv6ExtHdrs(t *testing.T) { - const nicID = 1 - tests := []struct { name string extHdr func(nextHdr uint8) ([]byte, uint8) shouldAccept bool + // Should we expect an ICMP response and if so, with what contents? + expectICMP bool + ICMPType header.ICMPv6Type + ICMPCode header.ICMPv6Code + pointer uint32 + multicast bool }{ { name: "None", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, shouldAccept: true, + expectICMP: false, }, { name: "hopbyhop with unknown option skippable action", @@ -334,9 +424,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, + }, + { + name: "hopbyhop with unknown option discard and send icmp action (unicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. + }, hopByHopExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "hopbyhop with unknown option discard and send icmp action", + name: "hopbyhop with unknown option discard and send icmp action (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -346,12 +457,38 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP if option is unknown. 191, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. }, hopByHopExtHdrID }, + multicast: true, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, + }, + { + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. + }, hopByHopExtHdrID + }, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "hopbyhop with unknown option discard and send icmp action unless multicast dest", + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -362,39 +499,77 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP unless packet is for multicast destination if // option is unknown. 255, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. }, hopByHopExtHdrID }, + multicast: true, shouldAccept: false, + expectICMP: false, }, { - name: "routing with zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID }, + name: "routing with zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 0, 2, 3, 4, 5, + }, routingExtHdrID + }, shouldAccept: true, }, { - name: "routing with non-zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID }, + name: "routing with non-zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 1, 2, 3, 4, 5, + }, routingExtHdrID + }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6ErroneousHeader, + pointer: header.IPv6FixedHeaderSize + 2, }, { - name: "atomic fragment with zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID }, + name: "atomic fragment with zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 0, 0, 0, 0, 0, 0, + }, fragmentExtHdrID + }, shouldAccept: true, }, { - name: "atomic fragment with non-zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + name: "atomic fragment with non-zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 0, 0, 1, 2, 3, 4, + }, fragmentExtHdrID + }, shouldAccept: true, + expectICMP: false, }, { - name: "fragment", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + name: "fragment", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 0, 1, 2, 3, 4, + }, fragmentExtHdrID + }, shouldAccept: false, + expectICMP: false, }, { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, + name: "No next header", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{}, + noNextHdrID + }, shouldAccept: false, + expectICMP: false, }, { name: "destination with unknown option skippable action", @@ -410,6 +585,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, destinationExtHdrID }, shouldAccept: true, + expectICMP: false, }, { name: "destination with unknown option discard action", @@ -425,9 +601,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, destinationExtHdrID }, shouldAccept: false, + expectICMP: false, + }, + { + name: "destination with unknown option discard and send icmp action (unicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + //^ 191 is an unknown option. + }, destinationExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "destination with unknown option discard and send icmp action", + name: "destination with unknown option discard and send icmp action (muilticast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -437,12 +634,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP if option is unknown. 191, 6, 1, 2, 3, 4, 5, 6, + //^ 191 is an unknown option. }, destinationExtHdrID }, + multicast: true, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "destination with unknown option discard and send icmp action unless multicast dest", + name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -453,22 +656,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP unless packet is for multicast destination if // option is unknown. 255, 6, 1, 2, 3, 4, 5, 6, + //^ 255 is unknown. }, destinationExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "routing - atomic fragment", + name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + nextHdr, 1, - // Fragment extension header. - nextHdr, 0, 0, 0, 1, 2, 3, 4, - }, routingExtHdrID + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + //^ 255 is unknown. + }, destinationExtHdrID }, - shouldAccept: true, + shouldAccept: false, + expectICMP: false, + multicast: true, }, { name: "atomic fragment - routing", @@ -502,12 +716,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { return []byte{ // Routing extension header. hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + // ^^^ The HopByHop extension header may not appear after the first + // extension header. // Hop By Hop extension header with skippable unknown option. nextHdr, 0, 62, 4, 1, 2, 3, 4, }, routingExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, + }, + { + name: "routing - hop by hop (with send icmp unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Routing extension header. + hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + // ^^^ The HopByHop extension header may not appear after the first + // extension header. + + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, routingExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, }, { name: "No next header", @@ -551,6 +795,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, { name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)", @@ -571,16 +816,17 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(0, 1280, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -588,6 +834,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } + // Add a default route so that a return packet knows where to go. + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) @@ -629,17 +883,21 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Serialize IPv6 fixed header. payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + dstAddr := tcpip.Address(addr2) + if test.multicast { + dstAddr = header.IPv6AllNodesMulticastAddress + } ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: ipv6NextHdr, HopLimit: 255, SrcAddr: addr1, - DstAddr: addr2, + DstAddr: dstAddr, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stats := s.Stats().UDP.PacketsReceived @@ -648,6 +906,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = 0", got) } + if !test.expectICMP { + if p, ok := e.Read(); ok { + t.Fatalf("unexpected packet received: %#v", p) + } + return + } + + // ICMP required. + p, ok := e.Read() + if !ok { + t.Fatalf("expected packet wasn't written out") + } + + // Pack the output packet into a single buffer.View as the checkers + // assume that. + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() + if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want { + t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want) + } + + ipHdr := header.IPv6(pkt) + checker.IPv6(t, ipHdr, checker.ICMPv6( + checker.ICMPv6Type(test.ICMPType), + checker.ICMPv6Code(test.ICMPCode))) + + // We know we are looking at no extension headers in the error ICMP + // packets. + icmpPkt := header.ICMPv6(ipHdr.Payload()) + // We know we sent small packets that won't be truncated when reflected + // back to us. + originalPacket := icmpPkt.Payload() + if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want { + t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want) + } + if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" { + t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff) + } return } @@ -673,20 +969,27 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // fragmentData holds the IPv6 payload for a fragmented IPv6 packet. type fragmentData struct { + srcAddr tcpip.Address + dstAddr tcpip.Address nextHdr uint8 data buffer.VectorisedView } func TestReceiveIPv6Fragments(t *testing.T) { - const nicID = 1 - const udpPayload1Length = 256 - const udpPayload2Length = 128 - const fragmentExtHdrLen = 8 - // Note, not all routing extension headers will be 8 bytes but this test - // uses 8 byte routing extension headers for most sub tests. - const routingExtHdrLen = 8 - - udpGen := func(payload []byte, multiplier uint8) buffer.View { + const ( + udpPayload1Length = 256 + udpPayload2Length = 128 + // Used to test cases where the fragment blocks are not a multiple of + // the fragment block size of 8 (RFC 8200 section 4.5). + udpPayload3Length = 127 + udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize + fragmentExtHdrLen = 8 + // Note, not all routing extension headers will be 8 bytes but this test + // uses 8 byte routing extension headers for most sub tests. + routingExtHdrLen = 8 + ) + + udpGen := func(payload []byte, multiplier uint8, src, dst tcpip.Address) buffer.View { payloadLen := len(payload) for i := 0; i < payloadLen; i++ { payload[i] = uint8(i) * multiplier @@ -702,19 +1005,31 @@ func TestReceiveIPv6Fragments(t *testing.T) { Length: uint16(udpLength), }) copy(u.Payload(), payload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength)) sum = header.Checksum(payload, sum) u.SetChecksum(^u.CalculateChecksum(sum)) return hdr.View() } - var udpPayload1Buf [udpPayload1Length]byte - udpPayload1 := udpPayload1Buf[:] - ipv6Payload1 := udpGen(udpPayload1, 1) + var udpPayload1Addr1ToAddr2Buf [udpPayload1Length]byte + udpPayload1Addr1ToAddr2 := udpPayload1Addr1ToAddr2Buf[:] + ipv6Payload1Addr1ToAddr2 := udpGen(udpPayload1Addr1ToAddr2, 1, addr1, addr2) + + var udpPayload1Addr3ToAddr2Buf [udpPayload1Length]byte + udpPayload1Addr3ToAddr2 := udpPayload1Addr3ToAddr2Buf[:] + ipv6Payload1Addr3ToAddr2 := udpGen(udpPayload1Addr3ToAddr2, 4, addr3, addr2) - var udpPayload2Buf [udpPayload2Length]byte - udpPayload2 := udpPayload2Buf[:] - ipv6Payload2 := udpGen(udpPayload2, 2) + var udpPayload2Addr1ToAddr2Buf [udpPayload2Length]byte + udpPayload2Addr1ToAddr2 := udpPayload2Addr1ToAddr2Buf[:] + ipv6Payload2Addr1ToAddr2 := udpGen(udpPayload2Addr1ToAddr2, 2, addr1, addr2) + + var udpPayload3Addr1ToAddr2Buf [udpPayload3Length]byte + udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:] + ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2) + + var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte + udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:] + ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2) tests := []struct { name string @@ -726,34 +1041,60 @@ func TestReceiveIPv6Fragments(t *testing.T) { name: "No fragmentation", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: uint8(header.UDPProtocolNumber), - data: ipv6Payload1.ToVectorisedView(), + data: ipv6Payload1Addr1ToAddr2.ToVectorisedView(), }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, }, { name: "Atomic fragment", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2), + []buffer.View{ + // Fragment extension header. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + + ipv6Payload1Addr1ToAddr2, + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Atomic fragment with size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1), + fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2), []buffer.View{ // Fragment extension header. buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), - ipv6Payload1, + ipv6Payload3Addr1ToAddr2, }, ), }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, }, { name: "Two fragments", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( fragmentExtHdrLen+64, @@ -763,31 +1104,189 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], }, ), }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments out of order", + fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1)-64, + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments with different Next Header values", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + // NextHeader value is different than the one in the first fragment, so + // this NextHeader should be ignored. + buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], }, ), }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { + name: "Two fragments with last fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload3Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload3Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2}, + }, + { + name: "Two fragments with first fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+63, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload3Addr1ToAddr2[:63], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-63, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload3Addr1ToAddr2[63:], + }, + ), + }, + }, + expectedPayloads: nil, }, { name: "Two fragments with different IDs", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( fragmentExtHdrLen+64, @@ -797,21 +1296,23 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1)-64, + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Fragment extension header. // // Fragment offset = 8, More = false, ID = 2 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, @@ -819,9 +1320,49 @@ func TestReceiveIPv6Fragments(t *testing.T) { expectedPayloads: nil, }, { + name: "Two fragments reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+65520, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[:65520], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8190, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[65520:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, + }, + { name: "Two fragments with per-fragment routing header with zero segments left", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: routingExtHdrID, data: buffer.NewVectorisedView( routingExtHdrLen+fragmentExtHdrLen+64, @@ -836,14 +1377,16 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: routingExtHdrID, data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64, + routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Routing extension header. // @@ -855,17 +1398,19 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 8, More = false, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, }, { name: "Two fragments with per-fragment routing header with non-zero segments left", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: routingExtHdrID, data: buffer.NewVectorisedView( routingExtHdrLen+fragmentExtHdrLen+64, @@ -880,14 +1425,16 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: routingExtHdrID, data: buffer.NewVectorisedView( - routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64, + routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Routing extension header. // @@ -899,7 +1446,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 9, More = false, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, @@ -910,6 +1457,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { name: "Two fragments with routing header with zero segments left", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( routingExtHdrLen+fragmentExtHdrLen+64, @@ -924,31 +1473,35 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Segments left = 0. buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1)-64, + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, }, - expectedPayloads: [][]byte{udpPayload1}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, }, { name: "Two fragments with routing header with non-zero segments left", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( routingExtHdrLen+fragmentExtHdrLen+64, @@ -963,21 +1516,23 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Segments left = 1. buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1)-64, + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, @@ -988,6 +1543,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { name: "Two fragments with routing header with zero segments left across fragments", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( // The length of this payload is fragmentExtHdrLen+8 because the @@ -1008,12 +1565,14 @@ func TestReceiveIPv6Fragments(t *testing.T) { ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( // The length of this payload is - // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of + // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of // the 16 byte routing extension header is in this fagment. - fragmentExtHdrLen+8+len(ipv6Payload1), + fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2), []buffer.View{ // Fragment extension header. // @@ -1023,7 +1582,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header (part 2) buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), - ipv6Payload1, + ipv6Payload1Addr1ToAddr2, }, ), }, @@ -1034,6 +1593,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { name: "Two fragments with routing header with non-zero segments left across fragments", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( // The length of this payload is fragmentExtHdrLen+8 because the @@ -1054,12 +1615,14 @@ func TestReceiveIPv6Fragments(t *testing.T) { ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( // The length of this payload is - // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of + // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of // the 16 byte routing extension header is in this fagment. - fragmentExtHdrLen+8+len(ipv6Payload1), + fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2), []buffer.View{ // Fragment extension header. // @@ -1069,7 +1632,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header (part 2) buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), - ipv6Payload1, + ipv6Payload1Addr1ToAddr2, }, ), }, @@ -1082,6 +1645,8 @@ func TestReceiveIPv6Fragments(t *testing.T) { name: "Two fragments with atomic", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( fragmentExtHdrLen+64, @@ -1091,47 +1656,53 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, // This fragment has the same ID as the other fragments but is an atomic // fragment. It should not interfere with the other fragments. { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload2), + fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2), []buffer.View{ // Fragment extension header. // // Fragment offset = 0, More = false, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}), - ipv6Payload2, + ipv6Payload2Addr1ToAddr2, }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1)-64, + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, }, - expectedPayloads: [][]byte{udpPayload2, udpPayload1}, + expectedPayloads: [][]byte{udpPayload2Addr1ToAddr2, udpPayload1Addr1ToAddr2}, }, { name: "Two interleaved fragmented packets", fragments: []fragmentData{ { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( fragmentExtHdrLen+64, @@ -1141,11 +1712,13 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload1[:64], + ipv6Payload1Addr1ToAddr2[:64], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( fragmentExtHdrLen+32, @@ -1155,50 +1728,124 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 0, More = true, ID = 2 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}), - ipv6Payload2[:32], + ipv6Payload2Addr1ToAddr2[:32], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload1)-64, + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, []buffer.View{ // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), - ipv6Payload1[64:], + ipv6Payload1Addr1ToAddr2[64:], }, ), }, { + srcAddr: addr1, + dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload2)-32, + fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2)-32, []buffer.View{ // Fragment extension header. // // Fragment offset = 4, More = false, ID = 2 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}), - ipv6Payload2[32:], + ipv6Payload2Addr1ToAddr2[32:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2}, + }, + { + name: "Two interleaved fragmented packets from different sources but with same ID", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr3, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+32, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr3ToAddr2[:32], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + { + srcAddr: addr3, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-32, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 4, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}), + + ipv6Payload1Addr3ToAddr2[32:], }, ), }, }, - expectedPayloads: [][]byte{udpPayload1, udpPayload2}, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(0, 1280, linkAddr1) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -1231,16 +1878,16 @@ func TestReceiveIPv6Fragments(t *testing.T) { PayloadLength: uint16(f.data.Size()), NextHeader: f.nextHdr, HopLimit: 255, - SrcAddr: addr1, - DstAddr: addr2, + SrcAddr: f.srcAddr, + DstAddr: f.dstAddr, }) vv := hdr.View().ToVectorisedView() vv.Append(f.data) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { @@ -1263,3 +1910,920 @@ func TestReceiveIPv6Fragments(t *testing.T) { }) } } + +func TestInvalidIPv6Fragments(t *testing.T) { + const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + nicID = 1 + hoplimit = 255 + ident = 1 + data = "TEST_INVALID_IPV6_FRAGMENTS" + ) + + type fragmentData struct { + ipv6Fields header.IPv6Fields + ipv6FragmentFields header.IPv6FragmentFields + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + wantMalformedIPPackets uint64 + wantMalformedFragments uint64 + expectICMP bool + expectICMPType header.ICMPv6Type + expectICMPCode header.ICMPv6Code + expectICMPTypeSpecific uint32 + }{ + { + name: "fragment size is not a multiple of 8 and the M flag is true", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 9, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0 >> 3, + M: true, + Identification: ident, + }, + payload: []byte(data)[:9], + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, + expectICMP: true, + expectICMPType: header.ICMPv6ParamProblem, + expectICMPCode: header.ICMPv6ErroneousHeader, + expectICMPTypeSpecific: header.IPv6PayloadLenOffset, + }, + { + name: "fragments reassembled into a payload exceeding the max IPv6 payload size", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, + M: false, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, + expectICMP: true, + expectICMPType: header.ICMPv6ParamProblem, + expectICMPCode: header.ICMPv6ErroneousHeader, + expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + NewProtocol, + }, + }) + e := channel.New(1, 1500, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }}) + + var expectICMPPayload buffer.View + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) + ip.Encode(&f.ipv6Fields) + + fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) + fragHDR.Encode(&f.ipv6FragmentFields) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + + if test.expectICMP { + expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(ProtocolNumber, pkt) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { + t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want) + } + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { + t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want) + } + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())), + checker.ICMPv6( + checker.ICMPv6Type(test.expectICMPType), + checker.ICMPv6Code(test.expectICMPCode), + checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific), + checker.ICMPv6Payload([]byte(expectICMPPayload)), + ), + ) + }) + } +} + +func TestFragmentReassemblyTimeout(t *testing.T) { + const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + nicID = 1 + hoplimit = 255 + ident = 1 + data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" + ) + + type fragmentData struct { + ipv6Fields header.IPv6Fields + ipv6FragmentFields header.IPv6FragmentFields + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + expectICMP bool + }{ + { + name: "first fragment only", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "two first fragments", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "second fragment only", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: false, + }, + { + name: "two fragments with a gap", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: true, + }, + { + name: "two fragments with a gap in reverse order", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + NewProtocol, + }, + Clock: clock, + }) + + e := channel.New(1, 1500, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }}) + + var firstFragmentSent buffer.View + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) + ip.Encode(&f.ipv6Fields) + + fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) + fragHDR.Encode(&f.ipv6FragmentFields) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + + if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 { + firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(ProtocolNumber, pkt) + } + + clock.Advance(ReassembleTimeout) + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + if firstFragmentSent == nil { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6TimeExceeded), + checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout), + checker.ICMPv6Payload([]byte(firstFragmentSent)), + ), + ) + }) + } +} + +func TestWriteStats(t *testing.T) { + const nPackets = 3 + tests := []struct { + name string + setup func(*testing.T, *stack.Stack) + allowPackets int + expectSent int + expectDropped int + expectWritten int + }{ + { + name: "Accept all", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: math.MaxInt32, + expectSent: nPackets, + expectDropped: 0, + expectWritten: nPackets, + }, { + name: "Accept all with error", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: nPackets - 1, + expectSent: nPackets - 1, + expectDropped: 0, + expectWritten: nPackets - 1, + }, { + name: "Drop all", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: 0, + expectDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule that matches only 1 + // of the 3 packets. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectDropped: 1, + expectWritten: nPackets, + }, + } + + writers := []struct { + name string + writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + }{ + { + name: "WritePacket", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + nWritten := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { + return nWritten, err + } + nWritten++ + } + return nWritten, nil + }, + }, { + name: "WritePackets", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) + }, + }, + } + + for _, writer := range writers { + t.Run(writer.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) + rt := buildRoute(t, ep) + var pkts stack.PacketBufferList + for i := 0; i < nPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), + Data: buffer.NewView(0).ToVectorisedView(), + }) + pkt.TransportHeader().Push(header.UDPMinimumSize) + pkts.PushBack(pkt) + } + + test.setup(t, rt.Stack()) + + nWritten, _ := writer.writePackets(&rt, pkts) + + if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { + t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + } + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { + t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + } + if nWritten != test.expectWritten { + t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + } + }) + } + }) + } +} + +func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC(1, _) failed: %s", err) + } + const ( + src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ) + if err := s.AddAddress(1, ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err) + } + { + mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff") + subnet, err := tcpip.NewSubnet(dst, mask) + if err != nil { + t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err) + } + return rt +} + +// limitedMatcher is an iptables matcher that matches after a certain number of +// packets are checked against it. +type limitedMatcher struct { + limit int +} + +// Name implements Matcher.Name. +func (*limitedMatcher) Name() string { + return "limitedMatcher" +} + +// Match implements Matcher.Match. +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { + if lm.limit == 0 { + return true, false + } + lm.limit-- + return false, false +} + +func TestClearEndpointFromProtocolOnClose(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint) + { + proto.mu.Lock() + _, hasEP := proto.mu.eps[ep] + proto.mu.Unlock() + if !hasEP { + t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep) + } + } + + ep.Close() + + { + proto.mu.Lock() + _, hasEP := proto.mu.eps[ep] + proto.mu.Unlock() + if hasEP { + t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep) + } + } +} + +type fragmentInfo struct { + offset uint16 + more bool + payloadSize uint16 +} + +var fragmentationTests = []struct { + description string + mtu uint32 + gso *stack.GSO + transHdrLen int + payloadSize int + wantFragments []fragmentInfo +}{ + { + description: "No fragmentation", + mtu: header.IPv6MinimumMTU, + gso: nil, + transHdrLen: 0, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1000, more: false}, + }, + }, + { + description: "Fragmented", + mtu: header.IPv6MinimumMTU, + gso: nil, + transHdrLen: 0, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 776, more: false}, + }, + }, + { + description: "Fragmented with mtu not a multiple of 8", + mtu: header.IPv6MinimumMTU + 1, + gso: nil, + transHdrLen: 0, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 776, more: false}, + }, + }, + { + description: "No fragmentation with big header", + mtu: 2000, + gso: nil, + transHdrLen: 100, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1100, more: false}, + }, + }, + { + description: "Fragmented with gso none", + mtu: header.IPv6MinimumMTU, + gso: &stack.GSO{Type: stack.GSONone}, + transHdrLen: 0, + payloadSize: 1400, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 176, more: false}, + }, + }, + { + description: "Fragmented with big header", + mtu: header.IPv6MinimumMTU, + gso: nil, + transHdrLen: 100, + payloadSize: 1200, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 76, more: false}, + }, + }, +} + +func TestFragmentationWritePacket(t *testing.T) { + const ( + ttl = 42 + tos = stack.DefaultTOS + transportProto = tcp.ProtocolNumber + ) + + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + source := pkt.Clone() + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }, pkt) + if err != nil { + t.Fatalf("WritePacket(_, _, _): = %s", err) + } + if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) + } + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + }) + } +} + +func TestFragmentationWritePackets(t *testing.T) { + const ttl = 42 + tests := []struct { + description string + insertBefore int + insertAfter int + }{ + { + description: "Single packet", + insertBefore: 0, + insertAfter: 0, + }, + { + description: "With packet before", + insertBefore: 1, + insertAfter: 0, + }, + { + description: "With packet after", + insertBefore: 0, + insertAfter: 1, + }, + { + description: "With packet before and after", + insertBefore: 1, + insertAfter: 1, + }, + } + tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + var pkts stack.PacketBufferList + for i := 0; i < test.insertBefore; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + source := pkt + pkts.PushBack(pkt.Clone()) + for i := 0; i < test.insertAfter; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + + wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter + n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }) + if n != wantTotalPackets || err != nil { + t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets) + } + if got := len(ep.WrittenPackets); got != wantTotalPackets { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) + } + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + + if wantTotalPackets == 0 { + return + } + + fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] + if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + }) + } + }) + } +} + +// TestFragmentationErrors checks that errors are returned from WritePacket +// correctly. +func TestFragmentationErrors(t *testing.T) { + const ttl = 42 + + tests := []struct { + description string + mtu uint32 + transHdrLen int + payloadSize int + allowPackets int + outgoingErrors int + mockError *tcpip.Error + wantError *tcpip.Error + }{ + { + description: "No frag", + mtu: 2000, + payloadSize: 1000, + transHdrLen: 0, + allowPackets: 0, + outgoingErrors: 1, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on first frag", + mtu: 1300, + payloadSize: 3000, + transHdrLen: 0, + allowPackets: 0, + outgoingErrors: 3, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on second frag", + mtu: 1500, + payloadSize: 4000, + transHdrLen: 0, + allowPackets: 1, + outgoingErrors: 2, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error when MTU is smaller than transport header", + mtu: header.IPv6MinimumMTU, + transHdrLen: 1500, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrMessageTooLong, + }, + { + description: "Error when MTU is smaller than IPv6 minimum MTU", + mtu: header.IPv6MinimumMTU - 1, + transHdrLen: 0, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrInvalidEndpointState, + }, + } + + for _, ft := range tests { + t.Run(ft.description, func(t *testing.T) { + pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + r := buildRoute(t, ep) + err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }, pkt) + if err != ft.wantError { + t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { + t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) + } + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go new file mode 100644 index 000000000..40da011f8 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -0,0 +1,2013 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "fmt" + "log" + "math/rand" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // defaultRetransmitTimer is the default amount of time to wait between + // sending reachability probes. + // + // Default taken from RETRANS_TIMER of RFC 4861 section 10. + defaultRetransmitTimer = time.Second + + // minimumRetransmitTimer is the minimum amount of time to wait between + // sending reachability probes. + // + // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here + // to make sure the messages are not sent all at once. We also come to this + // value because in the RetransmitTimer field of a Router Advertisement, a + // value of 0 means unspecified, so the smallest valid value is 1. Note, the + // unit of the RetransmitTimer field in the Router Advertisement is + // milliseconds. + minimumRetransmitTimer = time.Millisecond + + // defaultDupAddrDetectTransmits is the default number of NDP Neighbor + // Solicitation messages to send when doing Duplicate Address Detection + // for a tentative address. + // + // Default = 1 (from RFC 4862 section 5.1) + defaultDupAddrDetectTransmits = 1 + + // defaultMaxRtrSolicitations is the default number of Router + // Solicitation messages to send when an IPv6 endpoint becomes enabled. + // + // Default = 3 (from RFC 4861 section 10). + defaultMaxRtrSolicitations = 3 + + // defaultRtrSolicitationInterval is the default amount of time between + // sending Router Solicitation messages. + // + // Default = 4s (from 4861 section 10). + defaultRtrSolicitationInterval = 4 * time.Second + + // defaultMaxRtrSolicitationDelay is the default maximum amount of time + // to wait before sending the first Router Solicitation message. + // + // Default = 1s (from 4861 section 10). + defaultMaxRtrSolicitationDelay = time.Second + + // defaultHandleRAs is the default configuration for whether or not to + // handle incoming Router Advertisements as a host. + defaultHandleRAs = true + + // defaultDiscoverDefaultRouters is the default configuration for + // whether or not to discover default routers from incoming Router + // Advertisements, as a host. + defaultDiscoverDefaultRouters = true + + // defaultDiscoverOnLinkPrefixes is the default configuration for + // whether or not to discover on-link prefixes from incoming Router + // Advertisements' Prefix Information option, as a host. + defaultDiscoverOnLinkPrefixes = true + + // defaultAutoGenGlobalAddresses is the default configuration for + // whether or not to generate global IPv6 addresses in response to + // receiving a new Prefix Information option with its Autonomous + // Address AutoConfiguration flag set, as a host. + // + // Default = true. + defaultAutoGenGlobalAddresses = true + + // minimumRtrSolicitationInterval is the minimum amount of time to wait + // between sending Router Solicitation messages. This limit is imposed + // to make sure that Router Solicitation messages are not sent all at + // once, defeating the purpose of sending the initial few messages. + minimumRtrSolicitationInterval = 500 * time.Millisecond + + // minimumMaxRtrSolicitationDelay is the minimum amount of time to wait + // before sending the first Router Solicitation message. It is 0 because + // we cannot have a negative delay. + minimumMaxRtrSolicitationDelay = 0 + + // MaxDiscoveredDefaultRouters is the maximum number of discovered + // default routers. The stack should stop discovering new routers after + // discovering MaxDiscoveredDefaultRouters routers. + // + // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and + // SHOULD be more. + MaxDiscoveredDefaultRouters = 10 + + // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered + // on-link prefixes. The stack should stop discovering new on-link + // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link + // prefixes. + MaxDiscoveredOnLinkPrefixes = 10 + + // validPrefixLenForAutoGen is the expected prefix length that an + // address can be generated for. Must be 64 bits as the interface + // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so + // 128 - 64 = 64. + validPrefixLenForAutoGen = 64 + + // defaultAutoGenTempGlobalAddresses is the default configuration for whether + // or not to generate temporary SLAAC addresses. + defaultAutoGenTempGlobalAddresses = true + + // defaultMaxTempAddrValidLifetime is the default maximum valid lifetime + // for temporary SLAAC addresses generated as part of RFC 4941. + // + // Default = 7 days (from RFC 4941 section 5). + defaultMaxTempAddrValidLifetime = 7 * 24 * time.Hour + + // defaultMaxTempAddrPreferredLifetime is the default preferred lifetime + // for temporary SLAAC addresses generated as part of RFC 4941. + // + // Default = 1 day (from RFC 4941 section 5). + defaultMaxTempAddrPreferredLifetime = 24 * time.Hour + + // defaultRegenAdvanceDuration is the default duration before the deprecation + // of a temporary address when a new address will be generated. + // + // Default = 5s (from RFC 4941 section 5). + defaultRegenAdvanceDuration = 5 * time.Second + + // minRegenAdvanceDuration is the minimum duration before the deprecation + // of a temporary address when a new address will be generated. + minRegenAdvanceDuration = time.Duration(0) + + // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt + // SLAAC address regenerations in response to an IPv6 endpoint-local conflict. + maxSLAACAddrLocalRegenAttempts = 10 +) + +var ( + // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid + // Lifetime to update the valid lifetime of a generated address by + // SLAAC. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Min = 2hrs. + MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour + + // MaxDesyncFactor is the upper bound for the preferred lifetime's desync + // factor for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Must be greater than 0. + // + // Max = 10m (from RFC 4941 section 5). + MaxDesyncFactor = 10 * time.Minute + + // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the + // maximum preferred lifetime for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // This value guarantees that a temporary address is preferred for at + // least 1hr if the SLAAC prefix is valid for at least that time. + MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour + + // MinMaxTempAddrValidLifetime is the minimum value allowed for the + // maximum valid lifetime for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // This value guarantees that a temporary address is valid for at least + // 2hrs if the SLAAC prefix is valid for at least that time. + MinMaxTempAddrValidLifetime = 2 * time.Hour +) + +// NDPEndpoint is an endpoint that supports NDP. +type NDPEndpoint interface { + // SetNDPConfigurations sets the NDP configurations. + SetNDPConfigurations(NDPConfigurations) +} + +// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an +// NDP Router Advertisement informed the Stack about. +type DHCPv6ConfigurationFromNDPRA int + +const ( + _ DHCPv6ConfigurationFromNDPRA = iota + + // DHCPv6NoConfiguration indicates that no configurations are available via + // DHCPv6. + DHCPv6NoConfiguration + + // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. + // + // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6 + // returns all available configuration information when serving addresses. + DHCPv6ManagedAddress + + // DHCPv6OtherConfigurations indicates that other configuration information is + // available via DHCPv6. + // + // Other configurations are configurations other than addresses. Examples of + // other configurations are recursive DNS server list, DNS search lists and + // default gateway. + DHCPv6OtherConfigurations +) + +// NDPDispatcher is the interface integrators of netstack must implement to +// receive and handle NDP related events. +type NDPDispatcher interface { + // OnDuplicateAddressDetectionStatus is called when the DAD process for an + // address (addr) on a NIC (with ID nicID) completes. resolved is set to true + // if DAD completed successfully (no duplicate addr detected); false otherwise + // (addr was detected to be a duplicate on the link the NIC is a part of, or + // it was stopped for some other reason, such as the address being removed). + // If an error occured during DAD, err is set and resolved must be ignored. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) + + // OnDefaultRouterDiscovered is called when a new default router is + // discovered. Implementations must return true if the newly discovered + // router should be remembered. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool + + // OnDefaultRouterInvalidated is called when a discovered default router that + // was remembered is invalidated. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) + + // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered. + // Implementations must return true if the newly discovered on-link prefix + // should be remembered. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool + + // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that + // was remembered is invalidated. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) + + // OnAutoGenAddress is called when a new prefix with its autonomous address- + // configuration flag set is received and SLAAC was performed. Implementations + // may prevent the stack from assigning the address to the NIC by returning + // false. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + + // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC) + // is deprecated, but is still considered valid. Note, if an address is + // invalidated at the same ime it is deprecated, the deprecation event may not + // be received. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) + + // OnAutoGenAddressInvalidated is called when an auto-generated address + // (SLAAC) is invalidated. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) + + // OnRecursiveDNSServerOption is called when the stack learns of DNS servers + // through NDP. Note, the addresses may contain link-local addresses. + // + // It is up to the caller to use the DNS Servers only for their valid + // lifetime. OnRecursiveDNSServerOption may be called for new or + // already known DNS servers. If called with known DNS servers, their + // valid lifetimes must be refreshed to lifetime (it may be increased, + // decreased, or completely invalidated when lifetime = 0). + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) + + // OnDNSSearchListOption is called when the stack learns of DNS search lists + // through NDP. + // + // It is up to the caller to use the domain names in the search list + // for only their valid lifetime. OnDNSSearchListOption may be called + // with new or already known domain names. If called with known domain + // names, their valid lifetimes must be refreshed to lifetime (it may + // be increased, decreased or completely invalidated when lifetime = 0. + OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) + + // OnDHCPv6Configuration is called with an updated configuration that is + // available via DHCPv6 for the passed NIC. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) +} + +// NDPConfigurations is the NDP configurations for the netstack. +type NDPConfigurations struct { + // The number of Neighbor Solicitation messages to send when doing + // Duplicate Address Detection for a tentative address. + // + // Note, a value of zero effectively disables DAD. + DupAddrDetectTransmits uint8 + + // The amount of time to wait between sending Neighbor solicitation + // messages. + // + // Must be greater than or equal to 1ms. + RetransmitTimer time.Duration + + // The number of Router Solicitation messages to send when the IPv6 endpoint + // becomes enabled. + MaxRtrSolicitations uint8 + + // The amount of time between transmitting Router Solicitation messages. + // + // Must be greater than or equal to 0.5s. + RtrSolicitationInterval time.Duration + + // The maximum amount of time before transmitting the first Router + // Solicitation message. + // + // Must be greater than or equal to 0s. + MaxRtrSolicitationDelay time.Duration + + // HandleRAs determines whether or not Router Advertisements are processed. + HandleRAs bool + + // DiscoverDefaultRouters determines whether or not default routers are + // discovered from Router Advertisements, as per RFC 4861 section 6. This + // configuration is ignored if HandleRAs is false. + DiscoverDefaultRouters bool + + // DiscoverOnLinkPrefixes determines whether or not on-link prefixes are + // discovered from Router Advertisements' Prefix Information option, as per + // RFC 4861 section 6. This configuration is ignored if HandleRAs is false. + DiscoverOnLinkPrefixes bool + + // AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs + // SLAAC to auto-generate global SLAAC addresses in response to Prefix + // Information options, as per RFC 4862. + // + // Note, if an address was already generated for some unique prefix, as + // part of SLAAC, this option does not affect whether or not the + // lifetime(s) of the generated address changes; this option only + // affects the generation of new addresses as part of SLAAC. + AutoGenGlobalAddresses bool + + // AutoGenAddressConflictRetries determines how many times to attempt to retry + // generation of a permanent auto-generated address in response to DAD + // conflicts. + // + // If the method used to generate the address does not support creating + // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's + // MAC address), then no attempt is made to resolve the conflict. + AutoGenAddressConflictRetries uint8 + + // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC + // addresses are generated for an IPv6 endpoint as part of SLAAC privacy + // extensions, as per RFC 4941. + // + // Ignored if AutoGenGlobalAddresses is false. + AutoGenTempGlobalAddresses bool + + // MaxTempAddrValidLifetime is the maximum valid lifetime for temporary + // SLAAC addresses. + MaxTempAddrValidLifetime time.Duration + + // MaxTempAddrPreferredLifetime is the maximum preferred lifetime for + // temporary SLAAC addresses. + MaxTempAddrPreferredLifetime time.Duration + + // RegenAdvanceDuration is the duration before the deprecation of a temporary + // address when a new address will be generated. + RegenAdvanceDuration time.Duration +} + +// DefaultNDPConfigurations returns an NDPConfigurations populated with +// default values. +func DefaultNDPConfigurations() NDPConfigurations { + return NDPConfigurations{ + DupAddrDetectTransmits: defaultDupAddrDetectTransmits, + RetransmitTimer: defaultRetransmitTimer, + MaxRtrSolicitations: defaultMaxRtrSolicitations, + RtrSolicitationInterval: defaultRtrSolicitationInterval, + MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, + HandleRAs: defaultHandleRAs, + DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, + AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses, + MaxTempAddrValidLifetime: defaultMaxTempAddrValidLifetime, + MaxTempAddrPreferredLifetime: defaultMaxTempAddrPreferredLifetime, + RegenAdvanceDuration: defaultRegenAdvanceDuration, + } +} + +// validate modifies an NDPConfigurations with valid values. If invalid values +// are present in c, the corresponding default values are used instead. +func (c *NDPConfigurations) validate() { + if c.RetransmitTimer < minimumRetransmitTimer { + c.RetransmitTimer = defaultRetransmitTimer + } + + if c.RtrSolicitationInterval < minimumRtrSolicitationInterval { + c.RtrSolicitationInterval = defaultRtrSolicitationInterval + } + + if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay { + c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay + } + + if c.MaxTempAddrValidLifetime < MinMaxTempAddrValidLifetime { + c.MaxTempAddrValidLifetime = MinMaxTempAddrValidLifetime + } + + if c.MaxTempAddrPreferredLifetime < MinMaxTempAddrPreferredLifetime || c.MaxTempAddrPreferredLifetime > c.MaxTempAddrValidLifetime { + c.MaxTempAddrPreferredLifetime = MinMaxTempAddrPreferredLifetime + } + + if c.RegenAdvanceDuration < minRegenAdvanceDuration { + c.RegenAdvanceDuration = minRegenAdvanceDuration + } +} + +// ndpState is the per-interface NDP state. +type ndpState struct { + // The IPv6 endpoint this ndpState is for. + ep *endpoint + + // configs is the per-interface NDP configurations. + configs NDPConfigurations + + // The DAD state to send the next NS message, or resolve the address. + dad map[tcpip.Address]dadState + + // The default routers discovered through Router Advertisements. + defaultRouters map[tcpip.Address]defaultRouterState + + rtrSolicit struct { + // The timer used to send the next router solicitation message. + timer tcpip.Timer + + // Used to let the Router Solicitation timer know that it has been stopped. + // + // Must only be read from or written to while protected by the lock of + // the IPv6 endpoint this ndpState is associated with. MUST be set when the + // timer is set. + done *bool + } + + // The on-link prefixes discovered through Router Advertisements' Prefix + // Information option. + onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState + + // The SLAAC prefixes discovered through Router Advertisements' Prefix + // Information option. + slaacPrefixes map[tcpip.Subnet]slaacPrefixState + + // The last learned DHCPv6 configuration from an NDP RA. + dhcpv6Configuration DHCPv6ConfigurationFromNDPRA + + // temporaryIIDHistory is the history value used to generate a new temporary + // IID. + temporaryIIDHistory [header.IIDSize]byte + + // temporaryAddressDesyncFactor is the preferred lifetime's desync factor for + // temporary SLAAC addresses. + temporaryAddressDesyncFactor time.Duration +} + +// dadState holds the Duplicate Address Detection timer and channel to signal +// to the DAD goroutine that DAD should stop. +type dadState struct { + // The DAD timer to send the next NS message, or resolve the address. + timer tcpip.Timer + + // Used to let the DAD timer know that it has been stopped. + // + // Must only be read from or written to while protected by the lock of + // the IPv6 endpoint this dadState is associated with. + done *bool +} + +// defaultRouterState holds data associated with a default router discovered by +// a Router Advertisement (RA). +type defaultRouterState struct { + // Job to invalidate the default router. + // + // Must not be nil. + invalidationJob *tcpip.Job +} + +// onLinkPrefixState holds data associated with an on-link prefix discovered by +// a Router Advertisement's Prefix Information option (PI) when the NDP +// configurations was configured to do so. +type onLinkPrefixState struct { + // Job to invalidate the on-link prefix. + // + // Must not be nil. + invalidationJob *tcpip.Job +} + +// tempSLAACAddrState holds state associated with a temporary SLAAC address. +type tempSLAACAddrState struct { + // Job to deprecate the temporary SLAAC address. + // + // Must not be nil. + deprecationJob *tcpip.Job + + // Job to invalidate the temporary SLAAC address. + // + // Must not be nil. + invalidationJob *tcpip.Job + + // Job to regenerate the temporary SLAAC address. + // + // Must not be nil. + regenJob *tcpip.Job + + createdAt time.Time + + // The address's endpoint. + // + // Must not be nil. + addressEndpoint stack.AddressEndpoint + + // Has a new temporary SLAAC address already been regenerated? + regenerated bool +} + +// slaacPrefixState holds state associated with a SLAAC prefix. +type slaacPrefixState struct { + // Job to deprecate the prefix. + // + // Must not be nil. + deprecationJob *tcpip.Job + + // Job to invalidate the prefix. + // + // Must not be nil. + invalidationJob *tcpip.Job + + // Nonzero only when the address is not valid forever. + validUntil time.Time + + // Nonzero only when the address is not preferred forever. + preferredUntil time.Time + + // State associated with the stable address generated for the prefix. + stableAddr struct { + // The address's endpoint. + // + // May only be nil when the address is being (re-)generated. Otherwise, + // must not be nil as all SLAAC prefixes must have a stable address. + addressEndpoint stack.AddressEndpoint + + // The number of times an address has been generated locally where the IPv6 + // endpoint already had the generated address. + localGenerationFailures uint8 + } + + // The temporary (short-lived) addresses generated for the SLAAC prefix. + tempAddrs map[tcpip.Address]tempSLAACAddrState + + // The next two fields are used by both stable and temporary addresses + // generated for a SLAAC prefix. This is safe as only 1 address is in the + // generation and DAD process at any time. That is, no two addresses are + // generated at the same time for a given SLAAC prefix. + + // The number of times an address has been generated and added to the IPv6 + // endpoint. + // + // Addresses may be regenerated in reseponse to a DAD conflicts. + generationAttempts uint8 + + // The maximum number of times to attempt regeneration of a SLAAC address + // in response to DAD conflicts. + maxGenerationAttempts uint8 +} + +// startDuplicateAddressDetection performs Duplicate Address Detection. +// +// This function must only be called by IPv6 addresses that are currently +// tentative. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { + // addr must be a valid unicast IPv6 address. + if !header.IsV6UnicastAddress(addr) { + return tcpip.ErrAddressFamilyNotSupported + } + + if addressEndpoint.GetKind() != stack.PermanentTentative { + // The endpoint should be marked as tentative since we are starting DAD. + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID())) + } + + // Should not attempt to perform DAD on an address that is currently in the + // DAD process. + if _, ok := ndp.dad[addr]; ok { + // Should never happen because we should only ever call this function for + // newly created addresses. If we attemped to "add" an address that already + // existed, we would get an error since we attempted to add a duplicate + // address, or its reference count would have been increased without doing + // the work that would have been done for an address that was brand new. + // See endpoint.addAddressLocked. + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID())) + } + + remaining := ndp.configs.DupAddrDetectTransmits + if remaining == 0 { + addressEndpoint.SetKind(stack.Permanent) + + // Consider DAD to have resolved even if no DAD messages were actually + // transmitted. + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) + } + + return nil + } + + var done bool + var timer tcpip.Timer + // We initially start a timer to fire immediately because some of the DAD work + // cannot be done while holding the IPv6 endpoint's lock. This is effectively + // the same as starting a goroutine but we use a timer that fires immediately + // so we can reset it for the next DAD iteration. + timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() { + ndp.ep.mu.Lock() + defer ndp.ep.mu.Unlock() + + if done { + // If we reach this point, it means that the DAD timer fired after + // another goroutine already obtained the IPv6 endpoint lock and stopped + // DAD before this function obtained the NIC lock. Simply return here and + // do nothing further. + return + } + + if addressEndpoint.GetKind() != stack.PermanentTentative { + // The endpoint should still be marked as tentative since we are still + // performing DAD on it. + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) + } + + dadDone := remaining == 0 + + var err *tcpip.Error + if !dadDone { + // Use the unspecified address as the source address when performing DAD. + addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) + + // Do not hold the lock when sending packets which may be a long running + // task or may block link address resolution. We know this is safe + // because immediately after obtaining the lock again, we check if DAD + // has been stopped before doing any work with the IPv6 endpoint. Note, + // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if + // the address was removed. + ndp.ep.mu.Unlock() + err = ndp.sendDADPacket(addr, addressEndpoint) + ndp.ep.mu.Lock() + addressEndpoint.DecRef() + } + + if done { + // If we reach this point, it means that DAD was stopped after we released + // the IPv6 endpoint's read lock and before we obtained the write lock. + return + } + + if dadDone { + // DAD has resolved. + addressEndpoint.SetKind(stack.Permanent) + } else if err == nil { + // DAD is not done and we had no errors when sending the last NDP NS, + // schedule the next DAD timer. + remaining-- + timer.Reset(ndp.configs.RetransmitTimer) + return + } + + // At this point we know that either DAD is done or we hit an error sending + // the last NDP NS. Either way, clean up addr's DAD state and let the + // integrator know DAD has completed. + delete(ndp.dad, addr) + + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) + } + + // If DAD resolved for a stable SLAAC address, attempt generation of a + // temporary SLAAC address. + if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + } + }) + + ndp.dad[addr] = dadState{ + timer: timer, + done: &done, + } + + return nil +} + +// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns +// addr. +// +// addr must be a tentative IPv6 address on ndp's IPv6 endpoint. +// +// The IPv6 endpoint that ndp belongs to MUST NOT be locked. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { + snmc := header.SolicitedNodeAddr(addr) + + r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + // Route should resolve immediately since snmc is a multicast address so a + // remote link address can be calculated without a resolution process. + if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the IPv6 endpoint is not + // locked, the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { + return err + } + + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err)) + } else if c != nil { + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID())) + } + + icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) + icmpData.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmpData.NDPPayload()) + ns.SetTargetAddress(addr) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) + + sent := r.Stats().ICMP.V6PacketsSent + if err := r.WritePacket(nil, + stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, pkt, + ); err != nil { + sent.Dropped.Increment() + return err + } + sent.NeighborSolicit.Increment() + + return nil +} + +// stopDuplicateAddressDetection ends a running Duplicate Address Detection +// process. Note, this may leave the DAD process for a tentative address in +// such a state forever, unless some other external event resolves the DAD +// process (receiving an NA from the true owner of addr, or an NS for addr +// (implying another node is attempting to use addr)). It is up to the caller +// of this function to handle such a scenario. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { + dad, ok := ndp.dad[addr] + if !ok { + // Not currently performing DAD on addr, just return. + return + } + + if dad.timer != nil { + dad.timer.Stop() + dad.timer = nil + + *dad.done = true + dad.done = nil + } + + delete(ndp.dad, addr) + + // Let the integrator know DAD did not resolve. + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil) + } +} + +// handleRA handles a Router Advertisement message that arrived on the NIC +// this ndp is for. Does nothing if the NIC is configured to not handle RAs. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + // Is the IPv6 endpoint configured to handle RAs at all? + // + // Currently, the stack does not determine router interface status on a + // per-interface basis; it is a protocol-wide configuration, so we check the + // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding + // packets. + if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() { + return + } + + // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we + // only inform the dispatcher on configuration changes. We do nothing else + // with the information. + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + var configuration DHCPv6ConfigurationFromNDPRA + switch { + case ra.ManagedAddrConfFlag(): + configuration = DHCPv6ManagedAddress + + case ra.OtherConfFlag(): + configuration = DHCPv6OtherConfigurations + + default: + configuration = DHCPv6NoConfiguration + } + + if ndp.dhcpv6Configuration != configuration { + ndp.dhcpv6Configuration = configuration + ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration) + } + } + + // Is the IPv6 endpoint configured to discover default routers? + if ndp.configs.DiscoverDefaultRouters { + rtr, ok := ndp.defaultRouters[ip] + rl := ra.RouterLifetime() + switch { + case !ok && rl != 0: + // This is a new default router we are discovering. + // + // Only remember it if we currently know about less than + // MaxDiscoveredDefaultRouters routers. + if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters { + ndp.rememberDefaultRouter(ip, rl) + } + + case ok && rl != 0: + // This is an already discovered default router. Update + // the invalidation job. + rtr.invalidationJob.Cancel() + rtr.invalidationJob.Schedule(rl) + ndp.defaultRouters[ip] = rtr + + case ok && rl == 0: + // We know about the router but it is no longer to be + // used as a default router so invalidate it. + ndp.invalidateDefaultRouter(ip) + } + } + + // TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter + // Discovery. + + // We know the options is valid as far as wire format is concerned since + // we got the Router Advertisement, as documented by this fn. Given this + // we do not check the iterator for errors on calls to Next. + it, _ := ra.Options().Iter(false) + for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { + switch opt := opt.(type) { + case header.NDPRecursiveDNSServer: + if ndp.ep.protocol.ndpDisp == nil { + continue + } + + addrs, _ := opt.Addresses() + ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) + + case header.NDPDNSSearchList: + if ndp.ep.protocol.ndpDisp == nil { + continue + } + + domainNames, _ := opt.DomainNames() + ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) + + case header.NDPPrefixInformation: + prefix := opt.Subnet() + + // Is the prefix a link-local? + if header.IsV6LinkLocalAddress(prefix.ID()) { + // ...Yes, skip as per RFC 4861 section 6.3.4, + // and RFC 4862 section 5.5.3.b (for SLAAC). + continue + } + + // Is the Prefix Length 0? + if prefix.Prefix() == 0 { + // ...Yes, skip as this is an invalid prefix + // as all IPv6 addresses cannot be on-link. + continue + } + + if opt.OnLinkFlag() { + ndp.handleOnLinkPrefixInformation(opt) + } + + if opt.AutonomousAddressConfigurationFlag() { + ndp.handleAutonomousPrefixInformation(opt) + } + } + + // TODO(b/141556115): Do (MTU) Parameter Discovery. + } +} + +// invalidateDefaultRouter invalidates a discovered default router. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { + rtr, ok := ndp.defaultRouters[ip] + + // Is the router still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + rtr.invalidationJob.Cancel() + delete(ndp.defaultRouters, ip) + + // Let the integrator know a discovered default router is invalidated. + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) + } +} + +// rememberDefaultRouter remembers a newly discovered default router with IPv6 +// link-local address ip with lifetime rl. +// +// The router identified by ip MUST NOT already be known by the IPv6 endpoint. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { + ndpDisp := ndp.ep.protocol.ndpDisp + if ndpDisp == nil { + return + } + + // Inform the integrator when we discovered a default router. + if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) { + // Informed by the integrator to not remember the router, do + // nothing further. + return + } + + state := defaultRouterState{ + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + ndp.invalidateDefaultRouter(ip) + }), + } + + state.invalidationJob.Schedule(rl) + + ndp.defaultRouters[ip] = state +} + +// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 +// address with prefix prefix with lifetime l. +// +// The prefix identified by prefix MUST NOT already be known. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { + ndpDisp := ndp.ep.protocol.ndpDisp + if ndpDisp == nil { + return + } + + // Inform the integrator when we discovered an on-link prefix. + if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) { + // Informed by the integrator to not remember the prefix, do + // nothing further. + return + } + + state := onLinkPrefixState{ + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + ndp.invalidateOnLinkPrefix(prefix) + }), + } + + if l < header.NDPInfiniteLifetime { + state.invalidationJob.Schedule(l) + } + + ndp.onLinkPrefixes[prefix] = state +} + +// invalidateOnLinkPrefix invalidates a discovered on-link prefix. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { + s, ok := ndp.onLinkPrefixes[prefix] + + // Is the on-link prefix still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + s.invalidationJob.Cancel() + delete(ndp.onLinkPrefixes, prefix) + + // Let the integrator know a discovered on-link prefix is invalidated. + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix) + } +} + +// handleOnLinkPrefixInformation handles a Prefix Information option with +// its on-link flag set, as per RFC 4861 section 6.3.4. +// +// handleOnLinkPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the on-link flag is set. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { + prefix := pi.Subnet() + prefixState, ok := ndp.onLinkPrefixes[prefix] + vl := pi.ValidLifetime() + + if !ok && vl == 0 { + // Don't know about this prefix but it has a zero valid + // lifetime, so just ignore. + return + } + + if !ok && vl != 0 { + // This is a new on-link prefix we are discovering + // + // Only remember it if we currently know about less than + // MaxDiscoveredOnLinkPrefixes on-link prefixes. + if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes { + ndp.rememberOnLinkPrefix(prefix, vl) + } + return + } + + if ok && vl == 0 { + // We know about the on-link prefix, but it is + // no longer to be considered on-link, so + // invalidate it. + ndp.invalidateOnLinkPrefix(prefix) + return + } + + // This is an already discovered on-link prefix with a + // new non-zero valid lifetime. + // + // Update the invalidation job. + + prefixState.invalidationJob.Cancel() + + if vl < header.NDPInfiniteLifetime { + // Prefix is valid for a finite lifetime, schedule the job to execute after + // the new valid lifetime. + prefixState.invalidationJob.Schedule(vl) + } + + ndp.onLinkPrefixes[prefix] = prefixState +} + +// handleAutonomousPrefixInformation handles a Prefix Information option with +// its autonomous flag set, as per RFC 4862 section 5.5.3. +// +// handleAutonomousPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the autonomous flag is set. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { + vl := pi.ValidLifetime() + pl := pi.PreferredLifetime() + + // If the preferred lifetime is greater than the valid lifetime, + // silently ignore the Prefix Information option, as per RFC 4862 + // section 5.5.3.c. + if pl > vl { + return + } + + prefix := pi.Subnet() + + // Check if we already maintain SLAAC state for prefix. + if state, ok := ndp.slaacPrefixes[prefix]; ok { + // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes. + ndp.refreshSLAACPrefixLifetimes(prefix, &state, pl, vl) + ndp.slaacPrefixes[prefix] = state + return + } + + // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section + // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC. + if !ndp.configs.AutoGenGlobalAddresses { + return + } + + ndp.doSLAAC(prefix, pl, vl) +} + +// doSLAAC generates a new SLAAC address with the provided lifetimes +// for prefix. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { + // If we do not already have an address for this prefix and the valid + // lifetime is 0, no need to do anything further, as per RFC 4862 + // section 5.5.3.d. + if vl == 0 { + return + } + + // Make sure the prefix is valid (as far as its length is concerned) to + // generate a valid IPv6 address from an interface identifier (IID), as + // per RFC 4862 sectiion 5.5.3.d. + if prefix.Prefix() != validPrefixLenForAutoGen { + return + } + + state := slaacPrefixState{ + deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) + } + + ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint) + }), + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) + } + + ndp.invalidateSLAACPrefix(prefix, state) + }), + tempAddrs: make(map[tcpip.Address]tempSLAACAddrState), + maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, + } + + now := time.Now() + + // The time an address is preferred until is needed to properly generate the + // address. + if pl < header.NDPInfiniteLifetime { + state.preferredUntil = now.Add(pl) + } + + if !ndp.generateSLAACAddr(prefix, &state) { + // We were unable to generate an address for the prefix, we do not nothing + // further as there is no reason to maintain state or jobs for a prefix we + // do not have an address for. + return + } + + // Setup the initial jobs to deprecate and invalidate prefix. + + if pl < header.NDPInfiniteLifetime && pl != 0 { + state.deprecationJob.Schedule(pl) + } + + if vl < header.NDPInfiniteLifetime { + state.invalidationJob.Schedule(vl) + state.validUntil = now.Add(vl) + } + + // If the address is assigned (DAD resolved), generate a temporary address. + if state.stableAddr.addressEndpoint.GetKind() == stack.Permanent { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */) + } + + ndp.slaacPrefixes[prefix] = state +} + +// addAndAcquireSLAACAddr adds a SLAAC address to the IPv6 endpoint. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint { + // Inform the integrator that we have a new SLAAC address. + ndpDisp := ndp.ep.protocol.ndpDisp + if ndpDisp == nil { + return nil + } + + if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) { + // Informed by the integrator not to add the address. + return nil + } + + addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated) + if err != nil { + panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err)) + } + + return addressEndpoint +} + +// generateSLAACAddr generates a SLAAC address for prefix. +// +// Returns true if an address was successfully generated. +// +// Panics if the prefix is not a SLAAC prefix or it already has an address. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { + if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil { + panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, addressEndpoint.AddressWithPrefix())) + } + + // If we have already reached the maximum address generation attempts for the + // prefix, do not generate another address. + if state.generationAttempts == state.maxGenerationAttempts { + return false + } + + var generatedAddr tcpip.AddressWithPrefix + addrBytes := []byte(prefix.ID()) + + for i := 0; ; i++ { + // If we were unable to generate an address after the maximum SLAAC address + // local regeneration attempts, do nothing further. + if i == maxSLAACAddrLocalRegenAttempts { + return false + } + + dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures + if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil { + addrBytes = header.AppendOpaqueInterfaceIdentifier( + addrBytes[:header.IIDOffsetInIPv6Address], + prefix, + oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()), + dadCounter, + oIID.SecretKey, + ) + } else if dadCounter == 0 { + // Modified-EUI64 based IIDs have no way to resolve DAD conflicts, so if + // the DAD counter is non-zero, we cannot use this method. + // + // Only attempt to generate an interface-specific IID if we have a valid + // link address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + linkAddr := ndp.ep.nic.LinkAddress() + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return false + } + + // Generate an address within prefix from the modified EUI-64 of ndp's + // NIC's Ethernet MAC address. + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + } else { + // We have no way to regenerate an address in response to an address + // conflict when addresses are not generated with opaque IIDs. + return false + } + + generatedAddr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: validPrefixLenForAutoGen, + } + + if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) { + break + } + + state.stableAddr.localGenerationFailures++ + } + + if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil { + state.stableAddr.addressEndpoint = addressEndpoint + state.generationAttempts++ + return true + } + + return false +} + +// regenerateSLAACAddr regenerates an address for a SLAAC prefix. +// +// If generating a new address for the prefix fails, the prefix is invalidated. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate address for %s", prefix)) + } + + if ndp.generateSLAACAddr(prefix, &state) { + ndp.slaacPrefixes[prefix] = state + return + } + + // We were unable to generate a permanent address for the SLAAC prefix so + // invalidate the prefix as there is no reason to maintain state for a + // SLAAC prefix we do not have an address for. + ndp.invalidateSLAACPrefix(prefix, state) +} + +// generateTempSLAACAddr generates a new temporary SLAAC address. +// +// If resetGenAttempts is true, the prefix's generation counter is reset. +// +// Returns true if a new address was generated. +func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool { + // Are we configured to auto-generate new temporary global addresses for the + // prefix? + if !ndp.configs.AutoGenTempGlobalAddresses || prefix == header.IPv6LinkLocalPrefix.Subnet() { + return false + } + + if resetGenAttempts { + prefixState.generationAttempts = 0 + prefixState.maxGenerationAttempts = ndp.configs.AutoGenAddressConflictRetries + 1 + } + + // If we have already reached the maximum address generation attempts for the + // prefix, do not generate another address. + if prefixState.generationAttempts == prefixState.maxGenerationAttempts { + return false + } + + stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address + now := time.Now() + + // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary + // address is the lower of the valid lifetime of the stable address or the + // maximum temporary address valid lifetime. + vl := ndp.configs.MaxTempAddrValidLifetime + if prefixState.validUntil != (time.Time{}) { + if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL { + vl = prefixVL + } + } + + if vl <= 0 { + // Cannot create an address without a valid lifetime. + return false + } + + // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary + // address is the lower of the preferred lifetime of the stable address or the + // maximum temporary address preferred lifetime - the temporary address desync + // factor. + pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor + if prefixState.preferredUntil != (time.Time{}) { + if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL { + // Respect the preferred lifetime of the prefix, as per RFC 4941 section + // 3.3 step 4. + pl = prefixPL + } + } + + // As per RFC 4941 section 3.3 step 5, a temporary address is created only if + // the calculated preferred lifetime is greater than the advance regeneration + // duration. In particular, we MUST NOT create a temporary address with a zero + // Preferred Lifetime. + if pl <= ndp.configs.RegenAdvanceDuration { + return false + } + + // Attempt to generate a new address that is not already assigned to the IPv6 + // endpoint. + var generatedAddr tcpip.AddressWithPrefix + for i := 0; ; i++ { + // If we were unable to generate an address after the maximum SLAAC address + // local regeneration attempts, do nothing further. + if i == maxSLAACAddrLocalRegenAttempts { + return false + } + + generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr) + if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) { + break + } + } + + // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary + // address with a zero preferred lifetime. The checks above ensure this + // so we know the address is not deprecated. + addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaacTemp, false /* deprecated */) + if addressEndpoint == nil { + return false + } + + state := tempSLAACAddrState{ + deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr)) + } + + ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint) + }), + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to invalidate temporary address %s", generatedAddr)) + } + + ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) + }), + regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to regenerate temporary address after %s", generatedAddr)) + } + + // If an address has already been regenerated for this address, don't + // regenerate another address. + if tempAddrState.regenerated { + return + } + + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + tempAddrState.regenerated = ndp.generateTempSLAACAddr(prefix, &prefixState, true /* resetGenAttempts */) + prefixState.tempAddrs[generatedAddr.Address] = tempAddrState + ndp.slaacPrefixes[prefix] = prefixState + }), + createdAt: now, + addressEndpoint: addressEndpoint, + } + + state.deprecationJob.Schedule(pl) + state.invalidationJob.Schedule(vl) + state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration) + + prefixState.generationAttempts++ + prefixState.tempAddrs[generatedAddr.Address] = state + + return true +} + +// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate temporary address for %s", prefix)) + } + + ndp.generateTempSLAACAddr(prefix, &state, resetGenAttempts) + ndp.slaacPrefixes[prefix] = state +} + +// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) { + // If the preferred lifetime is zero, then the prefix should be deprecated. + deprecated := pl == 0 + if deprecated { + ndp.deprecateSLAACAddress(prefixState.stableAddr.addressEndpoint) + } else { + prefixState.stableAddr.addressEndpoint.SetDeprecated(false) + } + + // If prefix was preferred for some finite lifetime before, cancel the + // deprecation job so it can be reset. + prefixState.deprecationJob.Cancel() + + now := time.Now() + + // Schedule the deprecation job if prefix has a finite preferred lifetime. + if pl < header.NDPInfiniteLifetime { + if !deprecated { + prefixState.deprecationJob.Schedule(pl) + } + prefixState.preferredUntil = now.Add(pl) + } else { + prefixState.preferredUntil = time.Time{} + } + + // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: + // + // 1) If the received Valid Lifetime is greater than 2 hours or greater than + // RemainingLifetime, set the valid lifetime of the prefix to the + // advertised Valid Lifetime. + // + // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the + // advertised Valid Lifetime. + // + // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. + + if vl >= header.NDPInfiniteLifetime { + // Handle the infinite valid lifetime separately as we do not schedule a + // job in this case. + prefixState.invalidationJob.Cancel() + prefixState.validUntil = time.Time{} + } else { + var effectiveVl time.Duration + var rl time.Duration + + // If the prefix was originally set to be valid forever, assume the + // remaining time to be the maximum possible value. + if prefixState.validUntil == (time.Time{}) { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(prefixState.validUntil) + } + + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl > MinPrefixInformationValidLifetimeForUpdate { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate + } + + if effectiveVl != 0 { + prefixState.invalidationJob.Cancel() + prefixState.invalidationJob.Schedule(effectiveVl) + prefixState.validUntil = now.Add(effectiveVl) + } + } + + // If DAD is not yet complete on the stable address, there is no need to do + // work with temporary addresses. + if prefixState.stableAddr.addressEndpoint.GetKind() != stack.Permanent { + return + } + + // Note, we do not need to update the entries in the temporary address map + // after updating the jobs because the jobs are held as pointers. + var regenForAddr tcpip.Address + allAddressesRegenerated := true + for tempAddr, tempAddrState := range prefixState.tempAddrs { + // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary + // address is the lower of the valid lifetime of the stable address or the + // maximum temporary address valid lifetime. Note, the valid lifetime of a + // temporary address is relative to the address's creation time. + validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime) + if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 { + validUntil = prefixState.validUntil + } + + // If the address is no longer valid, invalidate it immediately. Otherwise, + // reset the invalidation job. + newValidLifetime := validUntil.Sub(now) + if newValidLifetime <= 0 { + ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState) + continue + } + tempAddrState.invalidationJob.Cancel() + tempAddrState.invalidationJob.Schedule(newValidLifetime) + + // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary + // address is the lower of the preferred lifetime of the stable address or + // the maximum temporary address preferred lifetime - the temporary address + // desync factor. Note, the preferred lifetime of a temporary address is + // relative to the address's creation time. + preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor) + if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { + preferredUntil = prefixState.preferredUntil + } + + // If the address is no longer preferred, deprecate it immediately. + // Otherwise, schedule the deprecation job again. + newPreferredLifetime := preferredUntil.Sub(now) + tempAddrState.deprecationJob.Cancel() + if newPreferredLifetime <= 0 { + ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint) + } else { + tempAddrState.addressEndpoint.SetDeprecated(false) + tempAddrState.deprecationJob.Schedule(newPreferredLifetime) + } + + tempAddrState.regenJob.Cancel() + if tempAddrState.regenerated { + } else { + allAddressesRegenerated = false + + if newPreferredLifetime <= ndp.configs.RegenAdvanceDuration { + // The new preferred lifetime is less than the advance regeneration + // duration so regenerate an address for this temporary address + // immediately after we finish iterating over the temporary addresses. + regenForAddr = tempAddr + } else { + tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) + } + } + } + + // Generate a new temporary address if all of the existing temporary addresses + // have been regenerated, or we need to immediately regenerate an address + // due to an update in preferred lifetime. + // + // If each temporay address has already been regenerated, no new temporary + // address is generated. To ensure continuation of temporary SLAAC addresses, + // we manually try to regenerate an address here. + if len(regenForAddr) != 0 || allAddressesRegenerated { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + if state, ok := prefixState.tempAddrs[regenForAddr]; ndp.generateTempSLAACAddr(prefix, prefixState, true /* resetGenAttempts */) && ok { + state.regenerated = true + prefixState.tempAddrs[regenForAddr] = state + } + } +} + +// deprecateSLAACAddress marks the address as deprecated and notifies the NDP +// dispatcher that address has been deprecated. +// +// deprecateSLAACAddress does nothing if the address is already deprecated. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint) { + if addressEndpoint.Deprecated() { + return + } + + addressEndpoint.SetDeprecated(true) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix()) + } +} + +// invalidateSLAACPrefix invalidates a SLAAC prefix. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { + ndp.cleanupSLAACPrefixResources(prefix, state) + + if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil { + // Since we are already invalidating the prefix, do not invalidate the + // prefix when removing the address. + if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err)) + } + } +} + +// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's +// resources. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) + } + + prefix := addr.Subnet() + state, ok := ndp.slaacPrefixes[prefix] + if !ok || state.stableAddr.addressEndpoint == nil || addr.Address != state.stableAddr.addressEndpoint.AddressWithPrefix().Address { + return + } + + if !invalidatePrefix { + // If the prefix is not being invalidated, disassociate the address from the + // prefix and do nothing further. + state.stableAddr.addressEndpoint.DecRef() + state.stableAddr.addressEndpoint = nil + ndp.slaacPrefixes[prefix] = state + return + } + + ndp.cleanupSLAACPrefixResources(prefix, state) +} + +// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry. +// +// Panics if the SLAAC prefix is not known. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) { + // Invalidate all temporary addresses. + for tempAddr, tempAddrState := range state.tempAddrs { + ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState) + } + + if state.stableAddr.addressEndpoint != nil { + state.stableAddr.addressEndpoint.DecRef() + state.stableAddr.addressEndpoint = nil + } + state.deprecationJob.Cancel() + state.invalidationJob.Cancel() + delete(ndp.slaacPrefixes, prefix) +} + +// invalidateTempSLAACAddr invalidates a temporary SLAAC address. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { + // Since we are already invalidating the address, do not invalidate the + // address when removing the address. + if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err)) + } + + ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState) +} + +// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary +// SLAAC address's resources from ndp. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) + } + + if !invalidateAddr { + return + } + + prefix := addr.Subnet() + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry to clean up temp addr %s resources", addr)) + } + + tempAddrState, ok := state.tempAddrs[addr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to clean up temp addr %s resources", addr)) + } + + ndp.cleanupTempSLAACAddrResources(state.tempAddrs, addr.Address, tempAddrState) +} + +// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's +// jobs and entry. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { + tempAddrState.addressEndpoint.DecRef() + tempAddrState.addressEndpoint = nil + tempAddrState.deprecationJob.Cancel() + tempAddrState.invalidationJob.Cancel() + tempAddrState.regenJob.Cancel() + delete(tempAddrs, tempAddr) +} + +// removeSLAACAddresses removes all SLAAC addresses. +// +// If keepLinkLocal is false, the SLAAC generated link-local address is removed. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) { + linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() + var linkLocalPrefixes int + for prefix, state := range ndp.slaacPrefixes { + // RFC 4862 section 5 states that routers are also expected to generate a + // link-local address so we do not invalidate them if we are cleaning up + // host-only state. + if keepLinkLocal && prefix == linkLocalSubnet { + linkLocalPrefixes++ + continue + } + + ndp.invalidateSLAACPrefix(prefix, state) + } + + if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { + panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) + } +} + +// cleanupState cleans up ndp's state. +// +// If hostOnly is true, then only host-specific state is cleaned up. +// +// This function invalidates all discovered on-link prefixes, discovered +// routers, and auto-generated addresses. +// +// If hostOnly is true, then the link-local auto-generated address aren't +// invalidated as routers are also expected to generate a link-local address. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupState(hostOnly bool) { + ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */) + + for prefix := range ndp.onLinkPrefixes { + ndp.invalidateOnLinkPrefix(prefix) + } + + if got := len(ndp.onLinkPrefixes); got != 0 { + panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got)) + } + + for router := range ndp.defaultRouters { + ndp.invalidateDefaultRouter(router) + } + + if got := len(ndp.defaultRouters); got != 0 { + panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got)) + } + + ndp.dhcpv6Configuration = 0 +} + +// startSolicitingRouters starts soliciting routers, as per RFC 4861 section +// 6.3.7. If routers are already being solicited, this function does nothing. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) startSolicitingRouters() { + if ndp.rtrSolicit.timer != nil { + // We are already soliciting routers. + return + } + + remaining := ndp.configs.MaxRtrSolicitations + if remaining == 0 { + return + } + + // Calculate the random delay before sending our first RS, as per RFC + // 4861 section 6.3.7. + var delay time.Duration + if ndp.configs.MaxRtrSolicitationDelay > 0 { + delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) + } + + var done bool + ndp.rtrSolicit.done = &done + ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() { + ndp.ep.mu.Lock() + if done { + // If we reach this point, it means that the RS timer fired after another + // goroutine already obtained the IPv6 endpoint lock and stopped + // solicitations. Simply return here and do nothing further. + ndp.ep.mu.Unlock() + return + } + + // As per RFC 4861 section 4.1, the source of the RS is an address assigned + // to the sending interface, or the unspecified address if no address is + // assigned to the sending interface. + addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false) + if addressEndpoint == nil { + // Incase this ends up creating a new temporary address, we need to hold + // onto the endpoint until a route is obtained. If we decrement the + // reference count before obtaing a route, the address's resources would + // be released and attempting to obtain a route after would fail. Once a + // route is obtainted, it is safe to decrement the reference count since + // obtaining a route increments the address's reference count. + addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) + } + ndp.ep.mu.Unlock() + + localAddr := addressEndpoint.AddressWithPrefix().Address + r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */) + addressEndpoint.DecRef() + if err != nil { + return + } + defer r.Release() + + // Route should resolve immediately since + // header.IPv6AllRoutersMulticastAddress is a multicast address so a + // remote link address can be calculated without a resolution process. + if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the IPv6 endpoint is + // not locked, the IPv6 endpoint could have been disabled or removed by + // another goroutine. + if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { + return + } + + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err)) + } else if c != nil { + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID())) + } + + // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source + // link-layer address option if the source address of the NDP RS is + // specified. This option MUST NOT be included if the source address is + // unspecified. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + var optsSerializer header.NDPOptionsSerializer + if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) { + optsSerializer = header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress), + } + } + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) + icmpData := header.ICMPv6(buffer.NewView(payloadSize)) + icmpData.SetType(header.ICMPv6RouterSolicit) + rs := header.NDPRouterSolicit(icmpData.NDPPayload()) + rs.Options().Serialize(optsSerializer) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) + + sent := r.Stats().ICMP.V6PacketsSent + if err := r.WritePacket(nil, + stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, pkt, + ); err != nil { + sent.Dropped.Increment() + log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) + // Don't send any more messages if we had an error. + remaining = 0 + } else { + sent.RouterSolicit.Increment() + remaining-- + } + + ndp.ep.mu.Lock() + if done || remaining == 0 { + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil + } else if ndp.rtrSolicit.timer != nil { + // Note, we need to explicitly check to make sure that + // the timer field is not nil because if it was nil but + // we still reached this point, then we know the IPv6 endpoint + // was requested to stop soliciting routers so we don't + // need to send the next Router Solicitation message. + ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) + } + ndp.ep.mu.Unlock() + }) + +} + +// stopSolicitingRouters stops soliciting routers. If routers are not currently +// being solicited, this function does nothing. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) stopSolicitingRouters() { + if ndp.rtrSolicit.timer == nil { + // Nothing to do. + return + } + + *ndp.rtrSolicit.done = true + ndp.rtrSolicit.timer.Stop() + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil +} + +// initializeTempAddrState initializes state related to temporary SLAAC +// addresses. +func (ndp *ndpState) initializeTempAddrState() { + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID()) + + if MaxDesyncFactor != 0 { + ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 12b70f7e9..ac20f217e 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -15,9 +15,12 @@ package ipv6 import ( + "context" "strings" "testing" + "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -30,12 +33,13 @@ import ( // setupStackAndEndpoint creates a stack with a single NIC with a link-local // address llladdr and an IPv6 endpoint to a remote with link-local address // rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { +func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeighborCache bool) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { @@ -63,14 +67,94 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) + ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) } + t.Cleanup(ep.Close) return s, ep } +var _ NDPDispatcher = (*testNDPDispatcher)(nil) + +// testNDPDispatcher is an NDPDispatcher only allows default router discovery. +type testNDPDispatcher struct { + addr tcpip.Address +} + +func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +} + +func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { + t.addr = addr + return true +} + +func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) { + t.addr = addr +} + +func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { + return false +} + +func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) { +} + +func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { + return false +} + +func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) { +} + +func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) { +} + +func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) { +} + +func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) { +} + +func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) { +} + +func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { + var ndpDisp testNDPDispatcher + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ + NDPDisp: &ndpDisp, + })}, + }) + + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) + } + + ipv6EP := ep.(*endpoint) + ipv6EP.mu.Lock() + ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour) + ipv6EP.mu.Unlock() + + if ndpDisp.addr != lladdr1 { + t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) + } + + ndpDisp.addr = "" + ndpEP := ep.(stack.NDPEndpoint) + ndpEP.InvalidateDefaultRouter(lladdr1) + if ndpDisp.addr != lladdr1 { + t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) + } +} + // TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. @@ -100,7 +184,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) if err := s.CreateNIC(nicID, e); err != nil { @@ -136,9 +220,9 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) if linkAddr != test.expectedLinkAddr { @@ -174,6 +258,123 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { } } +// TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests +// that receiving a valid NDP NS message with the Source Link Layer Address +// option results in a new entry in the link address cache for the sender of +// the message. +func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + optsBuf []byte + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "Valid", + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, + expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", + }, + { + name: "Too Small", + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, + }, + { + name: "Invalid Length", + optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: true, + }) + e := channel.New(0, 1280, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr0) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + neighbors, err := s.Neighbors(nicID) + if err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } + + neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) + for _, n := range neighbors { + if existing, ok := neighborByAddr[n.Addr]; ok { + if diff := cmp.Diff(existing, n); diff != "" { + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + } + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing) + } + neighborByAddr[n.Addr] = n + } + + if neigh, ok := neighborByAddr[lladdr1]; len(test.expectedLinkAddr) != 0 { + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + + if !ok { + t.Fatalf("expected a neighbor entry for %q", lladdr1) + } + if neigh.LinkAddr != test.expectedLinkAddr { + t.Errorf("got link address = %s, want = %s", neigh.LinkAddr, test.expectedLinkAddr) + } + if neigh.State != stack.Stale { + t.Errorf("got NUD state = %s, want = %s", neigh.State, stack.Stale) + } + } else { + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } + + if ok { + t.Fatalf("unexpectedly got neighbor entry: %s", neigh) + } + } + }) + } +} + func TestNeighorSolicitationResponse(t *testing.T) { const nicID = 1 nicAddr := lladdr0 @@ -183,26 +384,41 @@ func TestNeighorSolicitationResponse(t *testing.T) { remoteLinkAddr0 := linkAddr1 remoteLinkAddr1 := linkAddr2 + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } + tests := []struct { - name string - nsOpts header.NDPOptionsSerializer - nsSrcLinkAddr tcpip.LinkAddress - nsSrc tcpip.Address - nsDst tcpip.Address - nsInvalid bool - naDstLinkAddr tcpip.LinkAddress - naSolicited bool - naSrc tcpip.Address - naDst tcpip.Address + name string + nsOpts header.NDPOptionsSerializer + nsSrcLinkAddr tcpip.LinkAddress + nsSrc tcpip.Address + nsDst tcpip.Address + nsInvalid bool + naDstLinkAddr tcpip.LinkAddress + naSolicited bool + naSrc tcpip.Address + naDst tcpip.Address + performsLinkResolution bool }{ { - name: "Unspecified source to multicast destination", + name: "Unspecified source to solicited-node multicast destination", nsOpts: nil, nsSrcLinkAddr: remoteLinkAddr0, nsSrc: header.IPv6Any, nsDst: nicAddrSNMC, nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, + naDstLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), naSolicited: false, naSrc: nicAddr, naDst: header.IPv6AllNodesMulticastAddress, @@ -223,11 +439,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { nsSrcLinkAddr: remoteLinkAddr0, nsSrc: header.IPv6Any, nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: false, - naSrc: nicAddr, - naDst: header.IPv6AllNodesMulticastAddress, + nsInvalid: true, }, { name: "Unspecified source with source ll option to unicast destination", @@ -239,7 +451,6 @@ func TestNeighorSolicitationResponse(t *testing.T) { nsDst: nicAddr, nsInvalid: true, }, - { name: "Specified source with 1 source ll to multicast destination", nsOpts: header.NDPOptionsSerializer{ @@ -299,6 +510,10 @@ func TestNeighorSolicitationResponse(t *testing.T) { naSolicited: true, naSrc: nicAddr, naDst: remoteAddr, + // Since we send a unicast solicitations to a node without an entry for + // the remote, the node needs to perform neighbor discovery to get the + // remote's link address to send the advertisement response. + performsLinkResolution: true, }, { name: "Specified source with 1 source ll to unicast destination", @@ -341,86 +556,159 @@ func TestNeighorSolicitationResponse(t *testing.T) { }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - e := channel.New(1, 1280, nicLinkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) - } + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: stackTyp.useNeighborCache, + }) + e := channel.New(1, 1280, nicLinkAddr) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) + } - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(nicAddr) - opts := ns.Options() - opts.Serialize(test.nsOpts) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, - }) + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(nicAddr) + opts := ns.Options() + opts.Serialize(test.nsOpts) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) + if test.nsInvalid { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } - if test.nsInvalid { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } + if p, got := e.Read(); got { + t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) + } - if p, got := e.Read(); got { - t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) - } + // If we expected the NS to be invalid, we have nothing else to check. + return + } - // If we expected the NS to be invalid, we have nothing else to check. - return - } + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } + if test.performsLinkResolution { + p, got := e.ReadContext(context.Background()) + if !got { + t.Fatal("expected an NDP NS response") + } + + if p.Route.LocalAddress != nicAddr { + t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, nicAddr) + } + if p.Route.LocalLinkAddress != nicLinkAddr { + t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) + } + respNSDst := header.SolicitedNodeAddr(test.nsSrc) + if p.Route.RemoteAddress != respNSDst { + t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) + } + if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(nicAddr), + checker.DstAddr(respNSDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(test.nsSrc), + checker.NDPNSOptions([]header.NDPOption{ + header.NDPSourceLinkLayerAddressOption(nicLinkAddr), + }), + )) + + ser := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + } + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(test.nsSrc) + na.Options().Serialize(ser) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: test.nsSrc, + DstAddr: nicAddr, + }) + e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } - p, got := e.Read() - if !got { - t.Fatal("expected an NDP NA response") - } + p, got := e.ReadContext(context.Background()) + if !got { + t.Fatal("expected an NDP NA response") + } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) - } + if p.Route.LocalAddress != test.naSrc { + t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc) + } + if p.Route.LocalLinkAddress != nicLinkAddr { + t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) + } + if p.Route.RemoteAddress != test.naDst { + t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) + } + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + } - checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(test.naSrc), - checker.DstAddr(test.naDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNA( - checker.NDPNASolicitedFlag(test.naSolicited), - checker.NDPNATargetAddress(nicAddr), - checker.NDPNAOptions([]header.NDPOption{ - header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), - }), - )) + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.naSrc), + checker.DstAddr(test.naDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNA( + checker.NDPNASolicitedFlag(test.naSolicited), + checker.NDPNATargetAddress(nicAddr), + checker.NDPNAOptions([]header.NDPOption{ + header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), + }), + )) + }) + } }) } } @@ -461,7 +749,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) if err := s.CreateNIC(nicID, e); err != nil { @@ -497,9 +785,9 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) if linkAddr != test.expectedLinkAddr { @@ -535,200 +823,385 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { } } -func TestNDPValidation(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) - - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - - return s, ep, r - } - - handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { - nextHdr := uint8(header.ICMPv6ProtocolNumber) - if atomicFragment { - bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength) - bytes[0] = nextHdr - nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) - } - - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: nextHdr, - HopLimit: hopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) - ep.HandlePacket(r, stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) +// TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests +// that receiving a valid NDP NA message with the Target Link Layer Address +// option does not result in a new entry in the neighbor cache for the target +// of the message. +func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) { + const nicID = 1 - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + tests := []struct { + name string + optsBuf []byte + isValid bool }{ { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, + name: "Valid", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, + isValid: true, }, { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, + name: "Too Small", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, }, { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, + name: "Invalid Length", + optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, }, { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg + name: "Multiple", + optsBuf: []byte{ + 2, 1, 2, 3, 4, 5, 6, 7, + 2, 1, 2, 3, 4, 5, 6, 8, }, }, } - subTests := []struct { - name string - atomicFragment bool - hopLimit uint8 - code uint8 - valid bool + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: true, + }) + e := channel.New(0, 1280, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr1) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + neighbors, err := s.Neighbors(nicID) + if err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } + + neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) + for _, n := range neighbors { + if existing, ok := neighborByAddr[n.Addr]; ok { + if diff := cmp.Diff(existing, n); diff != "" { + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + } + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing) + } + neighborByAddr[n.Addr] = n + } + + if neigh, ok := neighborByAddr[lladdr1]; ok { + t.Fatalf("unexpectedly got neighbor entry: %s", neigh) + } + + if test.isValid { + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + } else { + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } + } + }) + } +} + +func TestNDPValidation(t *testing.T) { + stacks := []struct { + name string + useNeighborCache bool }{ { - name: "Valid", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 0, - valid: true, - }, - { - name: "Fragmented", - atomicFragment: true, - hopLimit: header.NDPHopLimit, - code: 0, - valid: false, - }, - { - name: "Invalid hop limit", - atomicFragment: false, - hopLimit: header.NDPHopLimit - 1, - code: 0, - valid: false, + name: "linkAddrCache", + useNeighborCache: false, }, { - name: "Invalid ICMPv6 code", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 1, - valid: false, + name: "neighborCache", + useNeighborCache: true, }, } - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - for _, test := range subTests { - t.Run(test.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { + t.Helper() - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetCode(test.code) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) + // Create a stack with the assigned link-local address lladdr0 + // and an endpoint to lladdr1. + s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache) - // Rx count of the NDP message should initially be 0. - if got := typStat.Value(); got != 0 { - t.Errorf("got %s = %d, want = 0", typ.name, got) - } + r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + } - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } + return s, ep, r + } - if t.Failed() { - t.FailNow() - } + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + nextHdr := uint8(header.ICMPv6ProtocolNumber) + var extensions buffer.View + if atomicFragment { + extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) + extensions[0] = nextHdr + nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) + } - handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + Data: payload.ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(payload) + len(extensions)), + NextHeader: nextHdr, + HopLimit: hopLimit, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { + t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) + } + ep.HandlePacket(r, pkt) + } - // Rx count of the NDP packet should have increased. - if got := typStat.Value(); got != 1 { - t.Errorf("got %s = %d, want = 1", typ.name, got) - } + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) - want := uint64(0) - if !test.valid { - // Invalid count should have increased. - want = 1 - } - if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) + var sllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + + types := []struct { + name string + typ header.ICMPv6Type + size int + extraData []byte + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + routerOnly bool + }{ + { + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + routerOnly: true, + }, + { + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + extraData: sllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, + } + + subTests := []struct { + name string + atomicFragment bool + hopLimit uint8 + code header.ICMPv6Code + valid bool + }{ + { + name: "Valid", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 0, + valid: true, + }, + { + name: "Fragmented", + atomicFragment: true, + hopLimit: header.NDPHopLimit, + code: 0, + valid: false, + }, + { + name: "Invalid hop limit", + atomicFragment: false, + hopLimit: header.NDPHopLimit - 1, + code: 0, + valid: false, + }, + { + name: "Invalid ICMPv6 code", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 1, + valid: false, + }, + } + + for _, typ := range types { + for _, isRouter := range []bool{false, true} { + name := typ.name + if isRouter { + name += " (Router)" } - }) + + t.Run(name, func(t *testing.T) { + for _, test := range subTests { + t.Run(test.name, func(t *testing.T) { + s, ep, r := setup(t) + defer r.Release() + + if isRouter { + // Enabling forwarding makes the stack act as a router. + s.SetForwarding(ProtocolNumber, true) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + routerOnly := stats.RouterOnlyPacketsDroppedByHost + typStat := typ.statCounter(stats) + + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetCode(test.code) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + + // Rx count of the NDP message should initially be 0. + if got := typStat.Value(); got != 0 { + t.Errorf("got %s = %d, want = 0", typ.name, got) + } + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + + // RouterOnlyPacketsReceivedByHost count should initially be 0. + if got := routerOnly.Value(); got != 0 { + t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + + if t.Failed() { + t.FailNow() + } + + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) + + // Rx count of the NDP packet should have increased. + if got := typStat.Value(); got != 1 { + t.Errorf("got %s = %d, want = 1", typ.name, got) + } + + want := uint64(0) + if !test.valid { + // Invalid count should have increased. + want = 1 + } + if got := invalid.Value(); got != want { + t.Errorf("got invalid = %d, want = %d", got, want) + } + + want = 0 + if test.valid && !isRouter && typ.routerOnly { + // RouterOnlyPacketsReceivedByHost count should have increased. + want = 1 + } + if got := routerOnly.Value(); got != want { + t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) + } + + }) + } + }) + } } }) } + } // TestRouterAdvertValidation tests that when the NIC is configured to handle // NDP Router Advertisement packets, it validates the Router Advertisement // properly before handling them. func TestRouterAdvertValidation(t *testing.T) { + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } + tests := []struct { name string src tcpip.Address hopLimit uint8 - code uint8 + code header.ICMPv6Code ndpPayload []byte expectedSuccess bool }{ @@ -845,61 +1318,67 @@ func TestRouterAdvertValidation(t *testing.T) { }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: stackTyp.useNeighborCache, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } - icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(test.code) - copy(pkt.NDPPayload(), test.ndpPayload) - payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.NDPPayload(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - rxRA := stats.RouterAdvert + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } - e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) - if got := rxRA.Value(); got != 1 { - t.Fatalf("got rxRA = %d, want = 1", got) - } + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } - if test.expectedSuccess { - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - } else { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + } + }) } }) } diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD new file mode 100644 index 000000000..d0ffc299a --- /dev/null +++ b/pkg/tcpip/network/testutil/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "testutil", + srcs = [ + "testutil.go", + ], + visibility = [ + "//pkg/tcpip/network/fragmentation:__pkg__", + "//pkg/tcpip/network/ipv4:__pkg__", + "//pkg/tcpip/network/ipv6:__pkg__", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go new file mode 100644 index 000000000..7cc52985e --- /dev/null +++ b/pkg/tcpip/network/testutil/testutil.go @@ -0,0 +1,144 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testutil defines types and functions used to test Network Layer +// functionality such as IP fragmentation. +package testutil + +import ( + "fmt" + "math/rand" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// MockLinkEndpoint is an endpoint used for testing, it stores packets written +// to it and can mock errors. +type MockLinkEndpoint struct { + // WrittenPackets is where packets written to the endpoint are stored. + WrittenPackets []*stack.PacketBuffer + + mtu uint32 + err *tcpip.Error + allowPackets int +} + +// NewMockLinkEndpoint creates a new MockLinkEndpoint. +// +// err is the error that will be returned once allowPackets packets are written +// to the endpoint. +func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint { + return &MockLinkEndpoint{ + mtu: mtu, + err: err, + allowPackets: allowPackets, + } +} + +// MTU implements LinkEndpoint.MTU. +func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu } + +// Capabilities implements LinkEndpoint.Capabilities. +func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 } + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } + +// LinkAddress implements LinkEndpoint.LinkAddress. +func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } + +// WritePacket implements LinkEndpoint.WritePacket. +func (ep *MockLinkEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if ep.allowPackets == 0 { + return ep.err + } + ep.allowPackets-- + ep.WrittenPackets = append(ep.WrittenPackets, pkt) + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + var n int + + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := ep.WritePacket(r, gso, protocol, pkt); err != nil { + return n, err + } + n++ + } + + return n, nil +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if ep.allowPackets == 0 { + return ep.err + } + ep.allowPackets-- + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + ep.WrittenPackets = append(ep.WrittenPackets, pkt) + + return nil +} + +// Attach implements LinkEndpoint.Attach. +func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} + +// IsAttached implements LinkEndpoint.IsAttached. +func (*MockLinkEndpoint) IsAttached() bool { return false } + +// Wait implements LinkEndpoint.Wait. +func (*MockLinkEndpoint) Wait() {} + +// ARPHardwareType implements LinkEndpoint.ARPHardwareType. +func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } + +// AddHeader implements LinkEndpoint.AddHeader. +func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + +// MakeRandPkt generates a randomized packet. transportHeaderLength indicates +// how many random bytes will be copied in the Transport Header. +// extraHeaderReserveLength indicates how much extra space will be reserved for +// the other headers. The payload is made from Views of the sizes listed in +// viewSizes. +func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer { + var views buffer.VectorisedView + + for _, s := range viewSizes { + newView := buffer.NewView(s) + if _, err := rand.Read(newView); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + views.AppendView(newView) + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength, + Data: views, + }) + pkt.NetworkProtocolNumber = proto + if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + return pkt +} |