summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/BUILD4
-rw-r--r--pkg/tcpip/network/arp/BUILD3
-rw-r--r--pkg/tcpip/network/arp/arp.go248
-rw-r--r--pkg/tcpip/network/arp/arp_test.go486
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD9
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go256
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go498
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go59
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go27
-rw-r--r--pkg/tcpip/network/ip_test.go1109
-rw-r--r--pkg/tcpip/network/ipv4/BUILD9
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go405
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go1234
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go2436
-rw-r--r--pkg/tcpip/network/ipv6/BUILD9
-rw-r--r--pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go40
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go753
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go1170
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go1434
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go1806
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go2013
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go1081
-rw-r--r--pkg/tcpip/network/testutil/BUILD21
-rw-r--r--pkg/tcpip/network/testutil/testutil.go144
24 files changed, 13324 insertions, 1930 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 6a4839fb8..c118a2929 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -9,13 +9,17 @@ go_test(
"ip_test.go",
],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
],
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index eddf7b725..8a6bcfc2c 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/stack",
],
)
@@ -28,5 +29,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 9d0797af7..a79379abb 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -15,20 +15,16 @@
// Package arp implements the ARP network protocol. It is used to resolve
// IPv4 addresses into link-local MAC addresses, and advertises IPv4
// addresses of its stack with the local network.
-//
-// To use it in the networking stack, pass arp.NewProtocol() as one of the
-// network protocols when calling stack.New. Then add an "arp" address to every
-// NIC on the stack that should respond to ARP requests. That is:
-//
-// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
-// // handle err
-// }
package arp
import (
+ "fmt"
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -40,53 +36,81 @@ const (
ProtocolAddress = tcpip.Address("arp")
)
-// endpoint implements stack.NetworkEndpoint.
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- protocol *protocol
- nicID tcpip.NICID
- linkEP stack.LinkEndpoint
+ stack.AddressableEndpointState
+
+ protocol *protocol
+
+ // enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ nic stack.NetworkInterface
linkAddrCache stack.LinkAddressCache
+ nud stack.NUDHandler
}
-// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
-func (e *endpoint) DefaultTTL() uint8 {
- return 0
+func (e *endpoint) Enable() *tcpip.Error {
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ e.setEnabled(true)
+ return nil
}
-func (e *endpoint) MTU() uint32 {
- lmtu := e.linkEP.MTU()
- return lmtu - uint32(e.MaxHeaderLength())
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
}
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
}
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
+// setEnabled sets the enabled status for the endpoint.
+func (e *endpoint) setEnabled(v bool) {
+ if v {
+ atomic.StoreUint32(&e.enabled, 1)
+ } else {
+ atomic.StoreUint32(&e.enabled, 0)
+ }
}
-func (e *endpoint) ID() *stack.NetworkEndpointID {
- return &stack.NetworkEndpointID{ProtocolAddress}
+func (e *endpoint) Disable() {
+ e.setEnabled(false)
}
-func (e *endpoint) PrefixLen() int {
+// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
+func (e *endpoint) DefaultTTL() uint8 {
return 0
}
+func (e *endpoint) MTU() uint32 {
+ lmtu := e.nic.MTU()
+ return lmtu - uint32(e.MaxHeaderLength())
+}
+
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.ARPSize
+ return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.AddressableEndpointState.Cleanup()
+}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return e.protocol.Number()
+ return ProtocolNumber
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
@@ -94,16 +118,16 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList
return 0, tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
-func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- v, ok := pkt.Data.PullUp(header.ARPSize)
- if !ok {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ if !e.isEnabled() {
return
}
- h := header.ARP(v)
+
+ h := header.ARP(pkt.NetworkHeader().View())
if !h.IsValid() {
return
}
@@ -111,30 +135,80 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
switch h.Op() {
case header.ARPRequest:
localAddr := tcpip.Address(h.ProtocolAddressTarget())
- if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
- return // we have no useful answer, ignore the request
+
+ if e.nud == nil {
+ if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
+ } else {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+
+ remoteAddr := tcpip.Address(h.ProtocolAddressSender())
+ remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
- hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
- packet := header.ARP(hdr.Prepend(header.ARPSize))
+
+ respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
+ })
+ packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize))
packet.SetIPv4OverEthernet()
packet.SetOp(header.ARPReply)
- copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
- copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
- copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
- copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
- Header: hdr,
- })
- fallthrough // also fill the cache from requests
+ // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
+ // link address.
+ _ = copy(packet.HardwareAddressSender(), e.nic.LinkAddress())
+ if n := copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+ origSender := h.HardwareAddressSender()
+ if n := copy(packet.HardwareAddressTarget(), origSender); n != header.EthernetAddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.EthernetAddressSize))
+ }
+ if n := copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+
+ // As per RFC 826, under Packet Reception:
+ // Swap hardware and protocol fields, putting the local hardware and
+ // protocol addresses in the sender fields.
+ //
+ // Send the packet to the (new) target hardware address on the same
+ // hardware on which the request was received.
+ _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt)
+
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+
+ if e.nud == nil {
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
+ return
+ }
+
+ // The solicited, override, and isRouter flags are not available for ARP;
+ // they are only available for IPv6 Neighbor Advertisements.
+ e.nud.HandleConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{
+ // Solicited and unsolicited (also referred to as gratuitous) ARP Replies
+ // are handled equivalently to a solicited Neighbor Advertisement.
+ Solicited: true,
+ // If a different link address is received than the one cached, the entry
+ // should always go to Stale.
+ Override: false,
+ // ARP does not distinguish between router and non-router hosts.
+ IsRouter: false,
+ })
}
}
// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
type protocol struct {
+ stack *stack.Stack
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
@@ -146,16 +220,15 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
- if addrWithPrefix.Address != ProtocolAddress {
- return nil, tcpip.ErrBadLocalAddress
- }
- return &endpoint{
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
protocol: p,
- nicID: nicID,
- linkEP: sender,
+ nic: nic,
linkAddrCache: linkAddrCache,
- }, nil
+ nud: nud,
+ }
+ e.AddressableEndpointState.Init(e)
+ return e
}
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
@@ -164,28 +237,50 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
- r := &stack.Route{
- RemoteLinkAddress: broadcastMAC,
+func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
+ if len(remoteLinkAddr) == 0 {
+ remoteLinkAddr = header.EthernetBroadcastAddress
}
- hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
- h := header.ARP(hdr.Prepend(header.ARPSize))
- h.SetIPv4OverEthernet()
- h.SetOp(header.ARPRequest)
- copy(h.HardwareAddressSender(), linkEP.LinkAddress())
- copy(h.ProtocolAddressSender(), localAddr)
- copy(h.ProtocolAddressTarget(), addr)
+ nicID := nic.ID()
+ if len(localAddr) == 0 {
+ addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
+ if err != nil {
+ return err
+ }
+
+ if len(addr.Address) == 0 {
+ return tcpip.ErrNetworkUnreachable
+ }
+
+ localAddr = addr.Address
+ } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
- Header: hdr,
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize,
})
+ h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
+ pkt.NetworkProtocolNumber = ProtocolNumber
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
+ // link address.
+ _ = copy(h.HardwareAddressSender(), nic.LinkAddress())
+ if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+ if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+ return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
- return broadcastMAC, true
+ return header.EthernetBroadcastAddress, true
}
if header.IsV4MulticastAddress(addr) {
return header.EthernetAddressFromMulticastIPv4Address(addr), true
@@ -194,12 +289,12 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
// SetOption implements stack.NetworkProtocol.SetOption.
-func (*protocol) SetOption(option interface{}) *tcpip.Error {
+func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements stack.NetworkProtocol.Option.
-func (*protocol) Option(option interface{}) *tcpip.Error {
+func (*protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
@@ -209,9 +304,16 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+// Parse implements stack.NetworkProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ return 0, false, parse.ARP(pkt)
+}
// NewProtocol returns an ARP network protocol.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
+//
+// Note, to make sure that the ARP endpoint receives ARP packets, the "arp"
+// address must be added to every NIC that should respond to ARP requests. See
+// ProtocolAddress for more details.
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return &protocol{stack: s}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 1646d9cde..bf1292bb8 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -16,10 +16,13 @@ package arp_test
import (
"context"
+ "fmt"
"strconv"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -32,54 +35,184 @@ import (
)
const (
+ nicID = 1
+
+ stackAddr = tcpip.Address("\x0a\x00\x00\x01")
stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
- stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
- stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
- stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+
+ remoteAddr = tcpip.Address("\x0a\x00\x00\x02")
+ remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
+
+ unknownAddr = tcpip.Address("\x0a\x00\x00\x03")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
+
+ // eventChanSize defines the size of event channels used by the neighbor
+ // cache's event dispatcher. The size chosen here needs to be sufficient to
+ // queue all the events received during tests before consumption.
+ // If eventChanSize is too small, the tests may deadlock.
+ eventChanSize = 32
+)
+
+type eventType uint8
+
+const (
+ entryAdded eventType = iota
+ entryChanged
+ entryRemoved
)
+func (t eventType) String() string {
+ switch t {
+ case entryAdded:
+ return "add"
+ case entryChanged:
+ return "change"
+ case entryRemoved:
+ return "remove"
+ default:
+ return fmt.Sprintf("unknown (%d)", t)
+ }
+}
+
+type eventInfo struct {
+ eventType eventType
+ nicID tcpip.NICID
+ entry stack.NeighborEntry
+}
+
+func (e eventInfo) String() string {
+ return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry)
+}
+
+// arpDispatcher implements NUDDispatcher to validate the dispatching of
+// events upon certain NUD state machine events.
+type arpDispatcher struct {
+ // C is where events are queued
+ C chan eventInfo
+}
+
+var _ stack.NUDDispatcher = (*arpDispatcher)(nil)
+
+func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) {
+ e := eventInfo{
+ eventType: entryAdded,
+ nicID: nicID,
+ entry: entry,
+ }
+ d.C <- e
+}
+
+func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) {
+ e := eventInfo{
+ eventType: entryChanged,
+ nicID: nicID,
+ entry: entry,
+ }
+ d.C <- e
+}
+
+func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) {
+ e := eventInfo{
+ eventType: entryRemoved,
+ nicID: nicID,
+ entry: entry,
+ }
+ d.C <- e
+}
+
+func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
+ select {
+ case got := <-d.C:
+ if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" {
+ return fmt.Errorf("got invalid event (-got +want):\n%s", diff)
+ }
+ case <-ctx.Done():
+ return fmt.Errorf("%s for %s", ctx.Err(), want)
+ }
+ return nil
+}
+
+func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error {
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ return d.waitForEvent(ctx, want)
+}
+
+func (d *arpDispatcher) nextEvent() (eventInfo, bool) {
+ select {
+ case event := <-d.C:
+ return event, true
+ default:
+ return eventInfo{}, false
+ }
+}
+
type testContext struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
+ s *stack.Stack
+ linkEP *channel.Endpoint
+ nudDisp *arpDispatcher
}
-func newTestContext(t *testing.T) *testContext {
+func newTestContext(t *testing.T, useNeighborCache bool) *testContext {
+ c := stack.DefaultNUDConfigurations()
+ // Transition from Reachable to Stale almost immediately to test if receiving
+ // probes refreshes positive reachability.
+ c.BaseReachableTime = time.Microsecond
+
+ d := arpDispatcher{
+ // Create an event channel large enough so the neighbor cache doesn't block
+ // while dispatching events. Blocking could interfere with the timing of
+ // NUD transitions.
+ C: make(chan eventInfo, eventChanSize),
+ }
+
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ NUDConfigs: c,
+ NUDDisp: &d,
+ UseNeighborCache: useNeighborCache,
})
- const defaultMTU = 65536
- ep := channel.New(256, defaultMTU, stackLinkAddr)
+ ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, wep); err != nil {
+ if err := s.CreateNIC(nicID, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
t.Fatalf("AddAddress for ipv4 failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ if !useNeighborCache {
+ // The remote address needs to be assigned to the NIC so we can receive and
+ // verify outgoing ARP packets. The neighbor cache isn't concerned with
+ // this; the tests that use linkAddrCache expect the ARP responses to be
+ // received by the same NIC.
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, remoteAddr); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ }
}
- if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
t.Fatalf("AddAddress for arp failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
- NIC: 1,
+ NIC: nicID,
}})
return &testContext{
- t: t,
- s: s,
- linkEP: ep,
+ s: s,
+ linkEP: ep,
+ nudDisp: &d,
}
}
@@ -88,7 +221,7 @@ func (c *testContext) cleanup() {
}
func TestDirectRequest(t *testing.T) {
- c := newTestContext(t)
+ c := newTestContext(t, false /* useNeighborCache */)
defer c.cleanup()
const senderMAC = "\x01\x02\x03\x04\x05\x06"
@@ -103,21 +236,21 @@ func TestDirectRequest(t *testing.T) {
inject := func(addr tcpip.Address) {
copy(h.ProtocolAddressTarget(), addr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{
+ c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: v.ToVectorisedView(),
- })
+ }))
}
- for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
+ for i, address := range []tcpip.Address{stackAddr, remoteAddr} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
inject(address)
pi, _ := c.linkEP.ReadContext(context.Background())
if pi.Proto != arp.ProtocolNumber {
t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
}
- rep := header.ARP(pi.Pkt.Header.View())
+ rep := header.ARP(pi.Pkt.NetworkHeader().View())
if !rep.IsValid() {
- t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength())
+ t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
}
if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
@@ -134,7 +267,7 @@ func TestDirectRequest(t *testing.T) {
})
}
- inject(stackAddrBad)
+ inject(unknownAddr)
// Sleep tests are gross, but this will only potentially flake
// if there's a bug. If there is no bug this will reliably
// succeed.
@@ -144,3 +277,302 @@ func TestDirectRequest(t *testing.T) {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
}
}
+
+func TestDirectRequestWithNeighborCache(t *testing.T) {
+ c := newTestContext(t, true /* useNeighborCache */)
+ defer c.cleanup()
+
+ tests := []struct {
+ name string
+ senderAddr tcpip.Address
+ senderLinkAddr tcpip.LinkAddress
+ targetAddr tcpip.Address
+ isValid bool
+ }{
+ {
+ name: "Loopback",
+ senderAddr: stackAddr,
+ senderLinkAddr: stackLinkAddr,
+ targetAddr: stackAddr,
+ isValid: true,
+ },
+ {
+ name: "Remote",
+ senderAddr: remoteAddr,
+ senderLinkAddr: remoteLinkAddr,
+ targetAddr: stackAddr,
+ isValid: true,
+ },
+ {
+ name: "RemoteInvalidTarget",
+ senderAddr: remoteAddr,
+ senderLinkAddr: remoteLinkAddr,
+ targetAddr: unknownAddr,
+ isValid: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Inject an incoming ARP request.
+ v := make(buffer.View, header.ARPSize)
+ h := header.ARP(v)
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), test.senderLinkAddr)
+ copy(h.ProtocolAddressSender(), test.senderAddr)
+ copy(h.ProtocolAddressTarget(), test.targetAddr)
+ c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{
+ Data: v.ToVectorisedView(),
+ })
+
+ if !test.isValid {
+ // No packets should be sent after receiving an invalid ARP request.
+ // There is no need to perform a blocking read here, since packets are
+ // sent in the same function that handles ARP requests.
+ if pkt, ok := c.linkEP.Read(); ok {
+ t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto)
+ }
+ return
+ }
+
+ // Verify an ARP response was sent.
+ pi, ok := c.linkEP.Read()
+ if !ok {
+ t.Fatal("expected ARP response to be sent, got none")
+ }
+
+ if pi.Proto != arp.ProtocolNumber {
+ t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
+ }
+ rep := header.ARP(pi.Pkt.NetworkHeader().View())
+ if !rep.IsValid() {
+ t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ t.Errorf("got HardwareAddressSender() = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
+ t.Errorf("got ProtocolAddressSender() = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want {
+ t.Errorf("got HardwareAddressTarget() = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
+ t.Errorf("got ProtocolAddressTarget() = %s, want = %s", got, want)
+ }
+
+ // Verify the sender was saved in the neighbor cache.
+ wantEvent := eventInfo{
+ eventType: entryAdded,
+ nicID: nicID,
+ entry: stack.NeighborEntry{
+ Addr: test.senderAddr,
+ LinkAddr: tcpip.LinkAddress(test.senderLinkAddr),
+ State: stack.Stale,
+ },
+ }
+ if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil {
+ t.Fatal(err)
+ }
+
+ neighbors, err := c.s.Neighbors(nicID)
+ if err != nil {
+ t.Fatalf("c.s.Neighbors(%d): %s", nicID, err)
+ }
+
+ neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
+ for _, n := range neighbors {
+ if existing, ok := neighborByAddr[n.Addr]; ok {
+ if diff := cmp.Diff(existing, n); diff != "" {
+ t.Fatalf("duplicate neighbor entry found (-existing +got):\n%s", diff)
+ }
+ t.Fatalf("exact neighbor entry duplicate found for addr=%s", n.Addr)
+ }
+ neighborByAddr[n.Addr] = n
+ }
+
+ neigh, ok := neighborByAddr[test.senderAddr]
+ if !ok {
+ t.Fatalf("expected neighbor entry with Addr = %s", test.senderAddr)
+ }
+ if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want {
+ t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want)
+ }
+ if got, want := neigh.State, stack.Stale; got != want {
+ t.Errorf("got neighbor State = %s, want = %s", got, want)
+ }
+
+ // No more events should be dispatched
+ for {
+ event, ok := c.nudDisp.nextEvent()
+ if !ok {
+ break
+ }
+ t.Errorf("unexpected %s", event)
+ }
+ })
+ }
+}
+
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.LinkEndpoint
+
+ nicID tcpip.NICID
+}
+
+func (t *testInterface) ID() tcpip.NICID {
+ return t.nicID
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ r := stack.Route{
+ NetProto: protocol,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
+}
+
+func TestLinkAddressRequest(t *testing.T) {
+ const nicID = 1
+
+ testAddr := tcpip.Address([]byte{1, 2, 3, 4})
+
+ tests := []struct {
+ name string
+ nicAddr tcpip.Address
+ localAddr tcpip.Address
+ remoteLinkAddr tcpip.LinkAddress
+
+ expectedErr *tcpip.Error
+ expectedLocalAddr tcpip.Address
+ expectedRemoteLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ },
+ {
+ name: "Multicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ },
+ {
+ name: "Unicast with unspecified source",
+ nicAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ },
+ {
+ name: "Multicast with unspecified source",
+ nicAddr: stackAddr,
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ },
+ {
+ name: "Unicast with unassigned address",
+ localAddr: testAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedErr: tcpip.ErrBadLocalAddress,
+ },
+ {
+ name: "Multicast with unassigned address",
+ localAddr: testAddr,
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrBadLocalAddress,
+ },
+ {
+ name: "Unicast with no local address available",
+ remoteLinkAddr: remoteLinkAddr,
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Multicast with no local address available",
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ })
+ p := s.NetworkProtocolInstance(arp.ProtocolNumber)
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ if err := s.CreateNIC(nicID, linkEP); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if len(test.nicAddr) != 0 {
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
+ }
+ }
+
+ // We pass a test network interface to LinkAddressRequest with the same
+ // NIC ID and link endpoint used by the NIC we created earlier so that we
+ // can mock a link address request and observe the packets sent to the
+ // link endpoint even though the stack uses the real NIC to validate the
+ // local address.
+ if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
+ t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+ }
+
+ if test.expectedErr != nil {
+ return
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
+ }
+
+ rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want {
+ t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index d1c728ccf..47fb63290 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -29,6 +29,8 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
],
)
@@ -41,5 +43,10 @@ go_test(
"reassembler_test.go",
],
library = ":fragmentation",
- deps = ["//pkg/tcpip/buffer"],
+ deps = [
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/network/testutil",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index f42abc4bb..936601287 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -13,32 +13,60 @@
// limitations under the License.
// Package fragmentation contains the implementation of IP fragmentation.
-// It is based on RFC 791 and RFC 815.
+// It is based on RFC 791, RFC 815 and RFC 8200.
package fragmentation
import (
+ "errors"
"fmt"
"log"
"time"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
-const DefaultReassembleTimeout = 30 * time.Second
+const (
+ // HighFragThreshold is the threshold at which we start trimming old
+ // fragmented packets. Linux uses a default value of 4 MB. See
+ // net.ipv4.ipfrag_high_thresh for more information.
+ HighFragThreshold = 4 << 20 // 4MB
-// HighFragThreshold is the threshold at which we start trimming old
-// fragmented packets. Linux uses a default value of 4 MB. See
-// net.ipv4.ipfrag_high_thresh for more information.
-const HighFragThreshold = 4 << 20 // 4MB
+ // LowFragThreshold is the threshold we reach to when we start dropping
+ // older fragmented packets. It's important that we keep enough room for newer
+ // packets to be re-assembled. Hence, this needs to be lower than
+ // HighFragThreshold enough. Linux uses a default value of 3 MB. See
+ // net.ipv4.ipfrag_low_thresh for more information.
+ LowFragThreshold = 3 << 20 // 3MB
-// LowFragThreshold is the threshold we reach to when we start dropping
-// older fragmented packets. It's important that we keep enough room for newer
-// packets to be re-assembled. Hence, this needs to be lower than
-// HighFragThreshold enough. Linux uses a default value of 3 MB. See
-// net.ipv4.ipfrag_low_thresh for more information.
-const LowFragThreshold = 3 << 20 // 3MB
+ // minBlockSize is the minimum block size for fragments.
+ minBlockSize = 1
+)
+
+var (
+ // ErrInvalidArgs indicates to the caller that that an invalid argument was
+ // provided.
+ ErrInvalidArgs = errors.New("invalid args")
+)
+
+// FragmentID is the identifier for a fragment.
+type FragmentID struct {
+ // Source is the source address of the fragment.
+ Source tcpip.Address
+
+ // Destination is the destination address of the fragment.
+ Destination tcpip.Address
+
+ // ID is the identification value of the fragment.
+ //
+ // This is a uint32 because IPv6 uses a 32-bit identification value.
+ ID uint32
+
+ // The protocol for the packet.
+ Protocol uint8
+}
// Fragmentation is the main structure that other modules
// of the stack should use to implement IP Fragmentation.
@@ -46,14 +74,19 @@ type Fragmentation struct {
mu sync.Mutex
highLimit int
lowLimit int
- reassemblers map[uint32]*reassembler
+ reassemblers map[FragmentID]*reassembler
rList reassemblerList
size int
timeout time.Duration
+ blockSize uint16
+ clock tcpip.Clock
+ releaseJob *tcpip.Job
}
// NewFragmentation creates a new Fragmentation.
//
+// blockSize specifies the fragment block size, in bytes.
+//
// highMemoryLimit specifies the limit on the memory consumed
// by the fragments stored by Fragmentation (overhead of internal data-structures
// is not accounted). Fragments are dropped when the limit is reached.
@@ -64,7 +97,7 @@ type Fragmentation struct {
// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new a packet with an
// already existing fragmentation-id arrives after the timeout.
-func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation {
if lowMemoryLimit >= highMemoryLimit {
lowMemoryLimit = highMemoryLimit
}
@@ -73,44 +106,100 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t
lowMemoryLimit = 0
}
- return &Fragmentation{
- reassemblers: make(map[uint32]*reassembler),
+ if blockSize < minBlockSize {
+ blockSize = minBlockSize
+ }
+
+ f := &Fragmentation{
+ reassemblers: make(map[FragmentID]*reassembler),
highLimit: highMemoryLimit,
lowLimit: lowMemoryLimit,
timeout: reassemblingTimeout,
+ blockSize: blockSize,
+ clock: clock,
}
+ f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked)
+
+ return f
}
-// Process processes an incoming fragment belonging to an ID
-// and returns a complete packet when all the packets belonging to that ID have been received.
-func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
+// Process processes an incoming fragment belonging to an ID and returns a
+// complete packet and its protocol number when all the packets belonging to
+// that ID have been received.
+//
+// [first, last] is the range of the fragment bytes.
+//
+// first must be a multiple of the block size f is configured with. The size
+// of the fragment data must be a multiple of the block size, unless there are
+// no fragments following this fragment (more set to false).
+//
+// proto is the protocol number marked in the fragment being processed. It has
+// to be given here outside of the FragmentID struct because IPv6 should not use
+// the protocol to identify a fragment.
+//
+// releaseCB is a callback that will run when the fragment reassembly of a
+// packet is complete or cancelled. releaseCB take a a boolean argument which is
+// true iff the reassembly is cancelled due to timeout. releaseCB should be
+// passed only with the first fragment of a packet. If more than one releaseCB
+// are passed for the same packet, only the first releaseCB will be saved for
+// the packet and the succeeding ones will be dropped by running them
+// immediately with a false argument.
+func (f *Fragmentation) Process(
+ id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) (
+ buffer.VectorisedView, uint8, bool, error) {
+ if first > last {
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
+ }
+
+ if first%f.blockSize != 0 {
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
+ }
+
+ fragmentSize := last - first + 1
+ if more && fragmentSize%f.blockSize != 0 {
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
+ }
+
+ if l := vv.Size(); l < int(fragmentSize) {
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
+ }
+ vv.CapLength(int(fragmentSize))
+
f.mu.Lock()
r, ok := f.reassemblers[id]
- if ok && r.tooOld(f.timeout) {
- // This is very likely to be an id-collision or someone performing a slow-rate attack.
- f.release(r)
- ok = false
- }
if !ok {
- r = newReassembler(id)
+ r = newReassembler(id, f.clock)
f.reassemblers[id] = r
+ wasEmpty := f.rList.Empty()
f.rList.PushFront(r)
+ if wasEmpty {
+ // If we have just pushed a first reassembler into an empty list, we
+ // should kickstart the release job. The release job will keep
+ // rescheduling itself until the list becomes empty.
+ f.releaseReassemblersLocked()
+ }
+ }
+ if releaseCB != nil {
+ if !r.setCallback(releaseCB) {
+ // We got a duplicate callback. Release it immediately.
+ releaseCB(false /* timedOut */)
+ }
}
f.mu.Unlock()
- res, done, consumed, err := r.process(first, last, more, vv)
+ res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv)
if err != nil {
// We probably got an invalid sequence of fragments. Just
// discard the reassembler and move on.
f.mu.Lock()
- f.release(r)
+ f.release(r, false /* timedOut */)
f.mu.Unlock()
- return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err)
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
}
f.mu.Lock()
f.size += consumed
if done {
- f.release(r)
+ f.release(r, false /* timedOut */)
}
// Evict reassemblers if we are consuming more memory than highLimit until
// we reach lowLimit.
@@ -120,14 +209,14 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
if tail == nil {
break
}
- f.release(tail)
+ f.release(tail, false /* timedOut */)
}
}
f.mu.Unlock()
- return res, done, nil
+ return res, firstFragmentProto, done, nil
}
-func (f *Fragmentation) release(r *reassembler) {
+func (f *Fragmentation) release(r *reassembler, timedOut bool) {
// Before releasing a fragment we need to check if r is already marked as done.
// Otherwise, we would delete it twice.
if r.checkDoneOrMark() {
@@ -141,4 +230,105 @@ func (f *Fragmentation) release(r *reassembler) {
log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
f.size = 0
}
+
+ r.release(timedOut) // releaseCB may run.
+}
+
+// releaseReassemblersLocked releases already-expired reassemblers, then
+// schedules the job to call back itself for the remaining reassemblers if
+// any. This function must be called with f.mu locked.
+func (f *Fragmentation) releaseReassemblersLocked() {
+ now := f.clock.NowMonotonic()
+ for {
+ // The reassembler at the end of the list is the oldest.
+ r := f.rList.Back()
+ if r == nil {
+ // The list is empty.
+ break
+ }
+ elapsed := time.Duration(now-r.creationTime) * time.Nanosecond
+ if f.timeout > elapsed {
+ // If the oldest reassembler has not expired, schedule the release
+ // job so that this function is called back when it has expired.
+ f.releaseJob.Schedule(f.timeout - elapsed)
+ break
+ }
+ // If the oldest reassembler has already expired, release it.
+ f.release(r, true /* timedOut*/)
+ }
+}
+
+// PacketFragmenter is the book-keeping struct for packet fragmentation.
+type PacketFragmenter struct {
+ transportHeader buffer.View
+ data buffer.VectorisedView
+ reserve int
+ fragmentPayloadLen int
+ fragmentCount int
+ currentFragment int
+ fragmentOffset int
+}
+
+// MakePacketFragmenter prepares the struct needed for packet fragmentation.
+//
+// pkt is the packet to be fragmented.
+//
+// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can
+// have.
+//
+// reserve is the number of bytes that should be reserved for the headers in
+// each generated fragment.
+func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter {
+ // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be
+ // repeated in each fragment. However we do not currently support any header
+ // of that kind yet, so the following computation is valid for both IPv4 and
+ // IPv6.
+ // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are
+ // supported for outbound packets, the fragmentable data should not include
+ // these headers.
+ var fragmentableData buffer.VectorisedView
+ fragmentableData.AppendView(pkt.TransportHeader().View())
+ fragmentableData.Append(pkt.Data)
+ fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen
+
+ return PacketFragmenter{
+ data: fragmentableData,
+ reserve: reserve,
+ fragmentPayloadLen: int(fragmentPayloadLen),
+ fragmentCount: int(fragmentCount),
+ }
+}
+
+// BuildNextFragment returns a packet with the payload of the next fragment,
+// along with the fragment's offset, the number of bytes copied and a boolean
+// indicating if there are more fragments left or not. If this function is
+// called again after it indicated that no more fragments were left, it will
+// panic.
+//
+// Note that the returned packet will not have its network and link headers
+// populated, but space for them will be reserved. The transport header will be
+// stored in the packet's data.
+func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) {
+ if pf.currentFragment >= pf.fragmentCount {
+ panic("BuildNextFragment should not be called again after the last fragment was returned")
+ }
+
+ fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: pf.reserve,
+ })
+
+ // Copy data for the fragment.
+ copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen)
+
+ offset := pf.fragmentOffset
+ pf.fragmentOffset += copied
+ pf.currentFragment++
+ more := pf.currentFragment != pf.fragmentCount
+
+ return fragPkt, offset, copied, more
+}
+
+// RemainingFragmentCount returns the number of fragments left to be built.
+func (pf *PacketFragmenter) RemainingFragmentCount() int {
+ return pf.fragmentCount - pf.currentFragment
}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 72c0f53be..5dcd10730 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -15,13 +15,21 @@
package fragmentation
import (
+ "errors"
"reflect"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
)
+// reassembleTimeout is dummy timeout used for testing, where the clock never
+// advances.
+const reassembleTimeout = 1
+
// vv is a helper to build VectorisedView from different strings.
func vv(size int, pieces ...string) buffer.VectorisedView {
views := make([]buffer.View, len(pieces))
@@ -33,16 +41,18 @@ func vv(size int, pieces ...string) buffer.VectorisedView {
}
type processInput struct {
- id uint32
+ id FragmentID
first uint16
last uint16
more bool
+ proto uint8
vv buffer.VectorisedView
}
type processOutput struct {
- vv buffer.VectorisedView
- done bool
+ vv buffer.VectorisedView
+ proto uint8
+ done bool
}
var processTestCases = []struct {
@@ -53,8 +63,8 @@ var processTestCases = []struct {
{
comment: "One ID",
in: []processInput{
- {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -62,12 +72,23 @@ var processTestCases = []struct {
},
},
{
+ comment: "Next Header protocol mismatch",
+ in: []processInput{
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")},
+ },
+ out: []processOutput{
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: vv(4, "01", "23"), proto: 6, done: true},
+ },
+ },
+ {
comment: "Two IDs",
in: []processInput{
- {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")},
- {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")},
- {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")},
+ {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -81,19 +102,27 @@ var processTestCases = []struct {
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{})
+ firstFragmentProto := c.in[0].proto
for i, in := range c.in {
- vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil)
if err != nil {
- t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err)
+ t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s",
+ in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err)
}
if !reflect.DeepEqual(vv, c.out[i].vv) {
- t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv)
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)",
+ in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView())
}
if done != c.out[i].done {
- t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done)
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)",
+ in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done)
}
if c.out[i].done {
+ if firstFragmentProto != proto {
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)",
+ in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto)
+ }
if _, ok := f.reassemblers[in.id]; ok {
t.Errorf("Process(%d) did not remove buffer from reassemblers", i)
}
@@ -109,53 +138,154 @@ func TestFragmentationProcess(t *testing.T) {
}
func TestReassemblingTimeout(t *testing.T) {
- timeout := time.Millisecond
- f := NewFragmentation(1024, 512, timeout)
- // Send first fragment with id = 0, first = 0, last = 0, and more = true.
- f.Process(0, 0, 0, true, vv(1, "0"))
- // Sleep more than the timeout.
- time.Sleep(2 * timeout)
- // Send another fragment that completes a packet.
- // However, no packet should be reassembled because the fragment arrived after the timeout.
- _, done, err := f.Process(0, 1, 1, false, vv(1, "1"))
- if err != nil {
- t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err)
- }
- if done {
- t.Errorf("Fragmentation does not respect the reassembling timeout.")
+ const (
+ reassemblyTimeout = time.Millisecond
+ protocol = 0xff
+ )
+
+ type fragment struct {
+ first uint16
+ last uint16
+ more bool
+ data string
+ }
+
+ type event struct {
+ // name is a nickname of this event.
+ name string
+
+ // clockAdvance is a duration to advance the clock. The clock advances
+ // before a fragment specified in the fragment field is processed.
+ clockAdvance time.Duration
+
+ // fragment is a fragment to process. This can be nil if there is no
+ // fragment to process.
+ fragment *fragment
+
+ // expectDone is true if the fragmentation instance should report the
+ // reassembly is done after the fragment is processd.
+ expectDone bool
+
+ // sizeAfterEvent is the expected size of the fragmentation instance after
+ // the event.
+ sizeAfterEvent int
+ }
+
+ half1 := &fragment{first: 0, last: 0, more: true, data: "0"}
+ half2 := &fragment{first: 1, last: 1, more: false, data: "1"}
+
+ tests := []struct {
+ name string
+ events []event
+ }{
+ {
+ name: "half1 and half2 are reassembled successfully",
+ events: []event{
+ {
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2",
+ fragment: half2,
+ expectDone: true,
+ sizeAfterEvent: 0,
+ },
+ },
+ },
+ {
+ name: "half1 timeout, half2 timeout",
+ events: []event{
+ {
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half1 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half1 reassembly timeout",
+ clockAdvance: 1,
+ sizeAfterEvent: 0,
+ },
+ {
+ name: "half2",
+ fragment: half2,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2 reassembly timeout",
+ clockAdvance: 1,
+ sizeAfterEvent: 0,
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock)
+ for _, event := range test.events {
+ clock.Advance(event.clockAdvance)
+ if frag := event.fragment; frag != nil {
+ _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil)
+ if err != nil {
+ t.Fatalf("%s: f.Process failed: %s", event.name, err)
+ }
+ if done != event.expectDone {
+ t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone)
+ }
+ }
+ if got, want := f.size, event.sizeAfterEvent; got != want {
+ t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want)
+ }
+ }
+ })
}
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(3, 1, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil)
// Send first fragment with id = 1.
- f.Process(1, 0, 0, true, vv(1, "1"))
+ f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil)
// Send first fragment with id = 2.
- f.Process(2, 0, 0, true, vv(1, "2"))
+ f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil)
// Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
// evicted.
- f.Process(3, 0, 0, true, vv(1, "3"))
+ f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil)
- if _, ok := f.reassemblers[0]; ok {
+ if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
}
- if _, ok := f.reassemblers[1]; ok {
+ if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok {
t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
}
- if _, ok := f.reassemblers[3]; !ok {
+ if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok {
t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
}
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(1, 0, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
// Send the same packet again.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
got := f.size
want := 1
@@ -163,3 +293,293 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
}
}
+
+func TestErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ blockSize uint16
+ first uint16
+ last uint16
+ more bool
+ data string
+ err error
+ }{
+ {
+ name: "exact block size without more",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: false,
+ data: "01",
+ },
+ {
+ name: "exact block size with more",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "01",
+ },
+ {
+ name: "exact block size with more and extra data",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "012",
+ },
+ {
+ name: "exact block size with more and too little data",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "0",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "not exact block size with more",
+ blockSize: 2,
+ first: 2,
+ last: 2,
+ more: true,
+ data: "0",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "not exact block size without more",
+ blockSize: 2,
+ first: 2,
+ last: 2,
+ more: false,
+ data: "0",
+ },
+ {
+ name: "first not a multiple of block size",
+ blockSize: 2,
+ first: 3,
+ last: 4,
+ more: true,
+ data: "01",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "first more than last",
+ blockSize: 2,
+ first: 4,
+ last: 3,
+ more: true,
+ data: "01",
+ err: ErrInvalidArgs,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
+ _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil)
+ if !errors.Is(err, test.err) {
+ t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
+ }
+ if done {
+ t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data)
+ }
+ })
+ }
+}
+
+type fragmentInfo struct {
+ remaining int
+ copied int
+ offset int
+ more bool
+}
+
+func TestPacketFragmenter(t *testing.T) {
+ const (
+ reserve = 60
+ proto = 0
+ )
+
+ tests := []struct {
+ name string
+ fragmentPayloadLen uint32
+ transportHeaderLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+ }{
+ {
+ name: "Packet exactly fits in MTU",
+ fragmentPayloadLen: 1280,
+ transportHeaderLen: 0,
+ payloadSize: 1280,
+ wantFragments: []fragmentInfo{
+ {remaining: 0, copied: 1280, offset: 0, more: false},
+ },
+ },
+ {
+ name: "Packet exactly does not fit in MTU",
+ fragmentPayloadLen: 1000,
+ transportHeaderLen: 0,
+ payloadSize: 1001,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 1000, offset: 0, more: true},
+ {remaining: 0, copied: 1, offset: 1000, more: false},
+ },
+ },
+ {
+ name: "Packet has a transport header",
+ fragmentPayloadLen: 560,
+ transportHeaderLen: 40,
+ payloadSize: 560,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 560, offset: 0, more: true},
+ {remaining: 0, copied: 40, offset: 560, more: false},
+ },
+ },
+ {
+ name: "Packet has a huge transport header",
+ fragmentPayloadLen: 500,
+ transportHeaderLen: 1300,
+ payloadSize: 500,
+ wantFragments: []fragmentInfo{
+ {remaining: 3, copied: 500, offset: 0, more: true},
+ {remaining: 2, copied: 500, offset: 500, more: true},
+ {remaining: 1, copied: 500, offset: 1000, more: true},
+ {remaining: 0, copied: 300, offset: 1500, more: false},
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto)
+ var originalPayload buffer.VectorisedView
+ originalPayload.AppendView(pkt.TransportHeader().View())
+ originalPayload.Append(pkt.Data)
+ var reassembledPayload buffer.VectorisedView
+ pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve)
+ for i := 0; ; i++ {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ wantFragment := test.wantFragments[i]
+ if got := pf.RemainingFragmentCount(); got != wantFragment.remaining {
+ t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining)
+ }
+ if copied != wantFragment.copied {
+ t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied)
+ }
+ if offset != wantFragment.offset {
+ t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset)
+ }
+ if more != wantFragment.more {
+ t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
+ }
+ if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen {
+ t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen)
+ }
+ if got := fragPkt.AvailableHeaderBytes(); got != reserve {
+ t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
+ }
+ if got := fragPkt.TransportHeader().View().Size(); got != 0 {
+ t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got)
+ }
+ reassembledPayload.Append(fragPkt.Data)
+ if !more {
+ if i != len(test.wantFragments)-1 {
+ t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1)
+ }
+ break
+ }
+ }
+ if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" {
+ t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestReleaseCallback(t *testing.T) {
+ const (
+ proto = 99
+ )
+
+ var result int
+ var callbackReasonIsTimeout bool
+ cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut }
+ cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut }
+
+ tests := []struct {
+ name string
+ callbacks []func(bool)
+ timeout bool
+ wantResult int
+ wantCallbackReasonIsTimeout bool
+ }{
+ {
+ name: "callback runs on release",
+ callbacks: []func(bool){cb1},
+ timeout: false,
+ wantResult: 1,
+ wantCallbackReasonIsTimeout: false,
+ },
+ {
+ name: "first callback is nil",
+ callbacks: []func(bool){nil, cb2},
+ timeout: false,
+ wantResult: 2,
+ wantCallbackReasonIsTimeout: false,
+ },
+ {
+ name: "two callbacks - first one is set",
+ callbacks: []func(bool){cb1, cb2},
+ timeout: false,
+ wantResult: 1,
+ wantCallbackReasonIsTimeout: false,
+ },
+ {
+ name: "callback runs on timeout",
+ callbacks: []func(bool){cb1},
+ timeout: true,
+ wantResult: 1,
+ wantCallbackReasonIsTimeout: true,
+ },
+ {
+ name: "no callbacks",
+ callbacks: []func(bool){nil},
+ timeout: false,
+ wantResult: 0,
+ wantCallbackReasonIsTimeout: false,
+ },
+ }
+
+ id := FragmentID{ID: 0}
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ result = 0
+ callbackReasonIsTimeout = false
+
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
+
+ for i, cb := range test.callbacks {
+ _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb)
+ if err != nil {
+ t.Errorf("f.Process error = %s", err)
+ }
+ }
+
+ r, ok := f.reassemblers[id]
+ if !ok {
+ t.Fatalf("Reassemberr not found")
+ }
+ f.release(r, test.timeout)
+
+ if result != test.wantResult {
+ t.Errorf("got result = %d, want = %d", result, test.wantResult)
+ }
+ if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout {
+ t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 0a83d81f2..c0cc0bde0 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -18,9 +18,9 @@ import (
"container/heap"
"fmt"
"math"
- "time"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -32,23 +32,24 @@ type hole struct {
type reassembler struct {
reassemblerEntry
- id uint32
+ id FragmentID
size int
+ proto uint8
mu sync.Mutex
holes []hole
deleted int
heap fragHeap
done bool
- creationTime time.Time
+ creationTime int64
+ callback func(bool)
}
-func newReassembler(id uint32) *reassembler {
+func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
r := &reassembler{
id: id,
holes: make([]hole, 0, 16),
- deleted: 0,
heap: make(fragHeap, 0, 8),
- creationTime: time.Now(),
+ creationTime: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
first: 0,
@@ -78,7 +79,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
return used
}
-func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) {
+func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
consumed := 0
@@ -86,7 +87,18 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, false, consumed, nil
+ return buffer.VectorisedView{}, 0, false, consumed, nil
+ }
+ // For IPv6, it is possible to have different Protocol values between
+ // fragments of a packet (because, unlike IPv4, the Protocol is not used to
+ // identify a fragment). In this case, only the Protocol of the first
+ // fragment must be used as per RFC 8200 Section 4.5.
+ //
+ // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded
+ // here (instead of just the protocol) because most IP options should be
+ // derived from the first fragment.
+ if first == 0 {
+ r.proto = proto
}
if r.updateHoles(first, last, more) {
// We store the incoming packet only if it filled some holes.
@@ -96,17 +108,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
}
// Check if all the holes have been deleted and we are ready to reassamble.
if r.deleted < len(r.holes) {
- return buffer.VectorisedView{}, false, consumed, nil
+ return buffer.VectorisedView{}, 0, false, consumed, nil
}
res, err := r.heap.reassemble()
if err != nil {
- return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err)
+ return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err)
}
- return res, true, consumed, nil
-}
-
-func (r *reassembler) tooOld(timeout time.Duration) bool {
- return time.Now().Sub(r.creationTime) > timeout
+ return res, r.proto, true, consumed, nil
}
func (r *reassembler) checkDoneOrMark() bool {
@@ -116,3 +124,24 @@ func (r *reassembler) checkDoneOrMark() bool {
r.mu.Unlock()
return prev
}
+
+func (r *reassembler) setCallback(c func(bool)) bool {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.callback != nil {
+ return false
+ }
+ r.callback = c
+ return true
+}
+
+func (r *reassembler) release(timedOut bool) {
+ r.mu.Lock()
+ callback := r.callback
+ r.callback = nil
+ r.mu.Unlock()
+
+ if callback != nil {
+ callback(timedOut)
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index 7eee0710d..fa2a70dc8 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -18,6 +18,8 @@ import (
"math"
"reflect"
"testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
type updateHolesInput struct {
@@ -94,7 +96,7 @@ var holesTestCases = []struct {
func TestUpdateHoles(t *testing.T) {
for _, c := range holesTestCases {
- r := newReassembler(0)
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
for _, i := range c.in {
r.updateHoles(i.first, i.last, i.more)
}
@@ -103,3 +105,26 @@ func TestUpdateHoles(t *testing.T) {
}
}
}
+
+func TestSetCallback(t *testing.T) {
+ result := 0
+ reasonTimeout := false
+
+ cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut }
+ cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut }
+
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
+ if !r.setCallback(cb1) {
+ t.Errorf("setCallback failed")
+ }
+ if r.setCallback(cb2) {
+ t.Errorf("setCallback should fail if one is already set")
+ }
+ r.release(true)
+ if result != 1 {
+ t.Errorf("got result = %d, want = 1", result)
+ }
+ if !reasonTimeout {
+ t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout)
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4c20301c6..969579601 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -15,34 +15,48 @@
package ip_test
import (
+ "strings"
"testing"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
- localIpv4Addr = "\x0a\x00\x00\x01"
- localIpv4PrefixLen = 24
- remoteIpv4Addr = "\x0a\x00\x00\x02"
- ipv4SubnetAddr = "\x0a\x00\x00\x00"
- ipv4SubnetMask = "\xff\xff\xff\x00"
- ipv4Gateway = "\x0a\x00\x00\x03"
- localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- localIpv6PrefixLen = 120
- remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
- ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ localIPv4Addr = "\x0a\x00\x00\x01"
+ remoteIPv4Addr = "\x0a\x00\x00\x02"
+ ipv4SubnetAddr = "\x0a\x00\x00\x00"
+ ipv4SubnetMask = "\xff\xff\xff\x00"
+ ipv4Gateway = "\x0a\x00\x00\x03"
+ localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
+ ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ nicID = 1
)
+var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: localIPv4Addr,
+ PrefixLen: 24,
+}
+
+var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: localIPv6Addr,
+ PrefixLen: 120,
+}
+
// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
// The former is used to pretend that it's a link endpoint so that we can
// inspect packets written by the network endpoints. The latter is used to
@@ -96,15 +110,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) {
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
+ return stack.TransportPacketHandled
}
// DeliverTransportControlPacket is called by network endpoints after parsing
// incoming control (ICMP) packets. This is used by the test object to verify
// that the results of the parsing are expected.
-func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
t.checkValues(trans, pkt.Data, remote, local)
if typ != t.typ {
t.t.Errorf("typ = %v, want %v", typ, t.typ)
@@ -150,19 +165,19 @@ func (*testObject) Wait() {}
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
if t.v4 {
- h := header.IPv4(pkt.Header.View())
+ h := header.IPv4(pkt.NetworkHeader().View())
prot = tcpip.TransportProtocolNumber(h.Protocol())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
} else {
- h := header.IPv6(pkt.Header.View())
+ h := header.IPv6(pkt.NetworkHeader().View())
prot = tcpip.TransportProtocolNumber(h.NextHeader())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
@@ -172,60 +187,345 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
-func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*testObject) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("not implemented")
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
- s.CreateNIC(1, loopback.New())
- s.AddAddress(1, ipv4.ProtocolNumber, local)
+ s.CreateNIC(nicID, loopback.New())
+ s.AddAddress(nicID, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
NIC: 1,
}})
- return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
+ return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
- s.CreateNIC(1, loopback.New())
- s.AddAddress(1, ipv6.ProtocolNumber, local)
+ s.CreateNIC(nicID, loopback.New())
+ s.AddAddress(nicID, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
NIC: 1,
}})
- return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
+ return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
}
-func buildDummyStack() *stack.Stack {
- return stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) {
+ t.Helper()
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
+ e := channel.New(0, 1280, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
+ if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
+ }
+
+ v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
+ if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
+ }
+
+ return s, e
+}
+
+func buildDummyStack(t *testing.T) *stack.Stack {
+ t.Helper()
+
+ s, _ := buildDummyStackWithLinkEndpoint(t)
+ return s
+}
+
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ testObject
+
+ mu struct {
+ sync.RWMutex
+ disabled bool
+ }
+}
+
+func (*testInterface) ID() tcpip.NICID {
+ return nicID
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (t *testInterface) Enabled() bool {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return !t.mu.disabled
+}
+
+func (t *testInterface) setEnabled(v bool) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.mu.disabled = !v
+}
+
+func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func TestSourceAddressValidation(t *testing.T) {
+ rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv4Addr,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv6Addr,
+ })
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ tests := []struct {
+ name string
+ srcAddress tcpip.Address
+ rxICMP func(*channel.Endpoint, tcpip.Address)
+ valid bool
+ }{
+ {
+ name: "IPv4 valid",
+ srcAddress: "\x01\x02\x03\x04",
+ rxICMP: rxIPv4ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv6 valid",
+ srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10",
+ rxICMP: rxIPv6ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv4 unspecified",
+ srcAddress: header.IPv4Any,
+ rxICMP: rxIPv4ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv6 unspecified",
+ srcAddress: header.IPv4Any,
+ rxICMP: rxIPv6ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv4 multicast",
+ srcAddress: "\xe0\x00\x00\x01",
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv6 multicast",
+ srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ rxICMP: rxIPv6ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv4 broadcast",
+ srcAddress: header.IPv4Broadcast,
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv4 subnet broadcast",
+ srcAddress: func() tcpip.Address {
+ subnet := localIPv4AddrWithPrefix.Subnet()
+ return subnet.Broadcast()
+ }(),
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, e := buildDummyStackWithLinkEndpoint(t)
+ test.rxICMP(e, test.srcAddress)
+
+ var wantValid uint64
+ if test.valid {
+ wantValid = 1
+ }
+
+ if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want {
+ t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid {
+ t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid)
+ }
+ })
+ }
+}
+
+func TestEnableWhenNICDisabled(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ }{
+ {
+ name: "IPv4",
+ protocolFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ },
+ {
+ name: "IPv6",
+ protocolFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var nic testInterface
+ nic.setEnabled(false)
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory},
+ })
+ p := s.NetworkProtocolInstance(test.protoNum)
+
+ // We pass nil for all parameters except the NetworkInterface and Stack
+ // since Enable only depends on these.
+ ep := p.NewEndpoint(&nic, nil, nil, nil)
+
+ // The endpoint should initially be disabled, regardless the NIC's enabled
+ // status.
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+ nic.setEnabled(true)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Attempting to enable the endpoint while the NIC is disabled should
+ // fail.
+ nic.setEnabled(false)
+ if err := ep.Enable(); err != tcpip.ErrNotPermitted {
+ t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted)
+ }
+ // ep should consider the NIC's enabled status when determining its own
+ // enabled status so we "enable" the NIC to read just the endpoint's
+ // enabled status.
+ nic.setEnabled(true)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Enabling the interface after the NIC has been enabled should succeed.
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+ if !ep.Enabled() {
+ t.Fatal("got ep.Enabled() = false, want = true")
+ }
+
+ // ep should consider the NIC's enabled status when determining its own
+ // enabled status.
+ nic.setEnabled(false)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Disabling the endpoint when the NIC is enabled should make the endpoint
+ // disabled.
+ nic.setEnabled(true)
+ ep.Disable()
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+ })
+ }
}
func TestIPv4Send(t *testing.T) {
- o := testObject{t: t, v4: true}
- proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: true,
+ },
}
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
+ defer ep.Close()
// Allocate and initialize the payload view.
payload := buffer.NewView(100)
@@ -233,33 +533,45 @@ func TestIPv4Send(t *testing.T) {
payload[i] = uint8(i)
}
- // Allocate the header buffer.
- hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+ // Setup the packet buffer.
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(ep.MaxHeaderLength()),
+ Data: payload.ToVectorisedView(),
+ })
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv4Addr
- o.dstAddr = remoteIpv4Addr
- o.contents = payload
+ nic.testObject.protocol = 123
+ nic.testObject.srcAddr = localIPv4Addr
+ nic.testObject.dstAddr = remoteIPv4Addr
+ nic.testObject.contents = payload
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: payload.ToVectorisedView(),
- }); err != nil {
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: 123,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
func TestIPv4Receive(t *testing.T) {
- o := testObject{t: t, v4: true}
- proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
}
totalLen := header.IPv4MinimumSize + 30
@@ -270,9 +582,10 @@ func TestIPv4Receive(t *testing.T) {
TotalLength: uint16(totalLen),
TTL: 20,
Protocol: 10,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -280,20 +593,24 @@ func TestIPv4Receive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[header.IPv4MinimumSize:totalLen]
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: view.ToVectorisedView(),
})
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if _, _, ok := proto.Parse(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(&r, pkt)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -303,7 +620,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
name string
expectedCount int
fragmentOffset uint16
- code uint8
+ code header.ICMPv4Code
expectedTyp stack.ControlType
expectedExtra uint32
trunc int
@@ -317,20 +634,26 @@ func TestIPv4ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
}
- r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb")
+ r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb")
if err != nil {
t.Fatal(err)
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
- proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ },
}
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
view := buffer.NewView(dataOffset + 8)
@@ -342,8 +665,9 @@ func TestIPv4ReceiveControl(t *testing.T) {
TTL: 20,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: "\x0a\x00\x00\xbb",
- DstAddr: localIpv4Addr,
+ DstAddr: localIPv4Addr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
// Create the ICMP header.
icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
@@ -360,41 +684,51 @@ func TestIPv4ReceiveControl(t *testing.T) {
TTL: 20,
Protocol: 10,
FragmentOffset: c.fragmentOffset,
- SrcAddr: localIpv4Addr,
- DstAddr: remoteIpv4Addr,
+ SrcAddr: localIPv4Addr,
+ DstAddr: remoteIPv4Addr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
// Make payload be non-zero.
for i := dataOffset; i < len(view); i++ {
view[i] = uint8(i)
}
+ icmp.SetChecksum(0)
+ checksum := ^header.Checksum(icmp, 0 /* initial */)
+ icmp.SetChecksum(checksum)
+
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
-
- vv := view[:len(view)-c.trunc].ToVectorisedView()
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: vv,
- })
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[dataOffset:]
+ nic.testObject.typ = c.expectedTyp
+ nic.testObject.extra = c.expectedExtra
+
+ ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
+ if want := c.expectedCount; nic.testObject.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
})
}
}
func TestIPv4FragmentationReceive(t *testing.T) {
- o := testObject{t: t, v4: true}
- proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
}
totalLen := header.IPv4MinimumSize + 24
@@ -408,9 +742,11 @@ func TestIPv4FragmentationReceive(t *testing.T) {
Protocol: 10,
FragmentOffset: 0,
Flags: header.IPv4FlagMoreFragments,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
+ ip1.SetChecksum(^ip1.CalculateChecksum())
+
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
frag1[i] = uint8(i)
@@ -424,48 +760,65 @@ func TestIPv4FragmentationReceive(t *testing.T) {
TTL: 20,
Protocol: 10,
FragmentOffset: 24,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
+ ip2.SetChecksum(^ip2.CalculateChecksum())
+
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
frag2[i] = uint8(i)
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
// Send first segment.
- ep.HandlePacket(&r, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: frag1.ToVectorisedView(),
})
- if o.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
+ if _, _, ok := proto.Parse(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(&r, pkt)
+ if nic.testObject.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
}
// Send second segment.
- ep.HandlePacket(&r, stack.PacketBuffer{
+ pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: frag2.ToVectorisedView(),
})
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if _, _, ok := proto.Parse(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(&r, pkt)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
func TestIPv6Send(t *testing.T) {
- o := testObject{t: t}
- proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
}
// Allocate and initialize the payload view.
@@ -474,33 +827,44 @@ func TestIPv6Send(t *testing.T) {
payload[i] = uint8(i)
}
- // Allocate the header buffer.
- hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+ // Setup the packet buffer.
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(ep.MaxHeaderLength()),
+ Data: payload.ToVectorisedView(),
+ })
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv6Addr
- o.dstAddr = remoteIpv6Addr
- o.contents = payload
+ nic.testObject.protocol = 123
+ nic.testObject.srcAddr = localIPv6Addr
+ nic.testObject.dstAddr = remoteIPv6Addr
+ nic.testObject.contents = payload
- r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: payload.ToVectorisedView(),
- }); err != nil {
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: 123,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
func TestIPv6Receive(t *testing.T) {
- o := testObject{t: t}
- proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
}
totalLen := header.IPv6MinimumSize + 30
@@ -510,8 +874,8 @@ func TestIPv6Receive(t *testing.T) {
PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
NextHeader: 10,
HopLimit: 20,
- SrcAddr: remoteIpv6Addr,
- DstAddr: localIpv6Addr,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
})
// Make payload be non-zero.
@@ -520,21 +884,25 @@ func TestIPv6Receive(t *testing.T) {
}
// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[header.IPv6MinimumSize:totalLen]
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[header.IPv6MinimumSize:totalLen]
- r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: view.ToVectorisedView(),
})
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if _, _, ok := proto.Parse(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(&r, pkt)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -548,7 +916,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
expectedCount int
fragmentOffset *uint16
typ header.ICMPv6Type
- code uint8
+ code header.ICMPv6Code
expectedTyp stack.ControlType
expectedExtra uint32
trunc int
@@ -565,7 +933,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
{"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
}
r, err := buildIPv6Route(
- localIpv6Addr,
+ localIPv6Addr,
"\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
)
if err != nil {
@@ -573,15 +941,20 @@ func TestIPv6ReceiveControl(t *testing.T) {
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
- proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ },
}
-
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
if c.fragmentOffset != nil {
dataOffset += header.IPv6FragmentHeaderSize
@@ -595,7 +968,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: 20,
SrcAddr: outerSrcAddr,
- DstAddr: localIpv6Addr,
+ DstAddr: localIPv6Addr,
})
// Create the ICMP header.
@@ -611,8 +984,8 @@ func TestIPv6ReceiveControl(t *testing.T) {
PayloadLength: 100,
NextHeader: 10,
HopLimit: 20,
- SrcAddr: localIpv6Addr,
- DstAddr: remoteIpv6Addr,
+ SrcAddr: localIPv6Addr,
+ DstAddr: remoteIPv6Addr,
})
// Build the fragmentation header if needed.
@@ -634,21 +1007,435 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Give packet to IPv6 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[dataOffset:]
+ nic.testObject.typ = c.expectedTyp
+ nic.testObject.extra = c.expectedExtra
// Set ICMPv6 checksum.
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: view[:len(view)-c.trunc].ToVectorisedView(),
- })
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
+ if want := c.expectedCount; nic.testObject.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
+ }
+ })
+ }
+}
+
+// truncatedPacket returns a PacketBuffer based on a truncated view. If view,
+// after truncation, is large enough to hold a network header, it makes part of
+// view the packet's NetworkHeader and the rest its Data. Otherwise all of view
+// becomes Data.
+func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer {
+ v := view[:len(view)-trunc]
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: v.ToVectorisedView(),
+ })
+ _, _ = pkt.NetworkHeader().Consume(netHdrLen)
+ return pkt
+}
+
+func TestWriteHeaderIncludedPacket(t *testing.T) {
+ const (
+ nicID = 1
+ transportProto = 5
+
+ dataLen = 4
+ optionsLen = 4
+ )
+
+ dataBuf := [dataLen]byte{1, 2, 3, 4}
+ data := dataBuf[:]
+
+ ipv4OptionsBuf := [optionsLen]byte{0, 1, 0, 1}
+ ipv4Options := ipv4OptionsBuf[:]
+
+ ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
+ ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
+
+ var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte
+ ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:]
+ if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
+ }
+ if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+
+ tests := []struct {
+ name string
+ protoFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ nicAddr tcpip.Address
+ remoteAddr tcpip.Address
+ pktGen func(*testing.T, tcpip.Address) buffer.View
+ checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
+ expectedErr *tcpip.Error
+ }{
+ {
+ name: "IPv4",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv4MinimumSize + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv4MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv4 with IHL too small",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv4MinimumSize + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 1,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ expectedErr: tcpip.ErrMalformedHeader,
+ },
+ {
+ name: "IPv4 too small",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip[:len(ip)-1])
+ },
+ expectedErr: tcpip.ErrMalformedHeader,
+ },
+ {
+ name: "IPv4 minimum size",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip)
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv4MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.IPFullLength(header.IPv4MinimumSize),
+ checker.IPPayload(nil),
+ )
+ },
+ },
+ {
+ name: "IPv4 with options",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ipHdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ totalLen := ipHdrLen + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv4(hdr.Prepend(ipHdrLen))
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(ipHdrLen),
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options))
+ }
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ if len(netHdr.View()) != hdrLen {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(hdrLen),
+ checker.IPFullLength(uint16(hdrLen+len(data))),
+ checker.IPv4Options(ipv4Options),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv6",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv6MinimumSize + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv6Any {
+ src = localIPv6Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv6MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv6 with extension header",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv6Any {
+ src = localIPv6Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))),
+ checker.IPPayload(ipv6PayloadWithExtHdr),
+ )
+ },
+ },
+ {
+ name: "IPv6 minimum size",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip)
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv6Any {
+ src = localIPv6Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv6MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(header.IPv6MinimumSize),
+ checker.IPPayload(nil),
+ )
+ },
+ },
+ {
+ name: "IPv6 too small",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip[:len(ip)-1])
+ },
+ expectedErr: tcpip.ErrMalformedHeader,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ subTests := []struct {
+ name string
+ srcAddr tcpip.Address
+ }{
+ {
+ name: "unspecified source",
+ srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
+ },
+ {
+ name: "random source",
+ srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
+ },
+ }
+
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
+ })
+ e := channel.New(1, 1280, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
+
+ r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err)
+ }
+ defer r.Release()
+
+ if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(),
+ })); err != test.expectedErr {
+ t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr)
+ }
+
+ if test.expectedErr != nil {
+ return
+ }
+
+ pkt, ok := e.Read()
+ if !ok {
+ t.Fatal("expected a packet to be written")
+ }
+ test.checker(t, pkt.Pkt, subTest.srcAddr)
+ })
}
})
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 78420d6e6..6252614ec 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -10,9 +10,11 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
@@ -26,14 +28,19 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 4cbefe5ab..cf287446e 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,16 +15,20 @@
package ipv4
import (
+ "errors"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// handleControl handles the case when an ICMP packet contains the headers of
-// the original packet that caused the ICMP one to be sent. This information is
-// used to find out which transport endpoint must be notified about the ICMP
-// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+// handleControl handles the case when an ICMP error packet contains the headers
+// of the original packet that caused the ICMP one to be sent. This information
+// is used to find out which transport endpoint must be notified about the ICMP
+// packet. We only expect the payload, not the enclosing ICMP packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
return
@@ -37,8 +41,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
// false.
//
// Drop packet if it doesn't have the basic IPv4 header or if the
- // original source address doesn't match the endpoint's address.
- if hdr.SourceAddress() != e.id.LocalAddress {
+ // original source address doesn't match an address we own.
+ src := hdr.SourceAddress()
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -53,12 +58,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
// Skip the ip header, then deliver control message.
pkt.Data.TrimFront(hlen)
p := hdr.TransportProtocol()
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
+func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
stats := r.Stats()
received := stats.ICMP.V4PacketsReceived
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their
+ // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok {
received.Invalid.Increment()
@@ -66,47 +74,142 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
}
h := header.ICMPv4(v)
+ // Only do in-stack processing if the checksum is correct.
+ if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff {
+ received.Invalid.Increment()
+ // It's possible that a raw socket expects to receive this regardless
+ // of checksum errors. If it's an echo request we know it's safe because
+ // we are the only handler, however other types do not cope well with
+ // packets with checksum errors.
+ switch h.Type() {
+ case header.ICMPv4Echo:
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ }
+ return
+ }
+
+ iph := header.IPv4(pkt.NetworkHeader().View())
+ var newOptions header.IPv4Options
+ if len(iph) > header.IPv4MinimumSize {
+ // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip
+ // type ICMP packets):
+ // If a Record Route and/or Time Stamp option is received in an
+ // ICMP Echo Request, this option (these options) SHOULD be
+ // updated to include the current host and included in the IP
+ // header of the Echo Reply message, without "truncation".
+ // Thus, the recorded route will be for the entire round trip.
+ //
+ // So we need to let the option processor know how it should handle them.
+ var op optionsUsage
+ if h.Type() == header.ICMPv4Echo {
+ op = &optionUsageEcho{}
+ } else {
+ op = &optionUsageReceive{}
+ }
+ aux, tmp, err := processIPOptions(r, iph.Options(), op)
+ if err != nil {
+ switch {
+ case
+ errors.Is(err, header.ErrIPv4OptDuplicate),
+ errors.Is(err, errIPv4RecordRouteOptInvalidLength),
+ errors.Is(err, errIPv4RecordRouteOptInvalidPointer),
+ errors.Is(err, errIPv4TimestampOptInvalidLength),
+ errors.Is(err, errIPv4TimestampOptInvalidPointer),
+ errors.Is(err, errIPv4TimestampOptOverflow):
+ _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt)
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ }
+ return
+ }
+ newOptions = tmp
+ }
+
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv4Echo:
received.Echo.Increment()
- // Only send a reply if the checksum is valid.
- wantChecksum := h.Checksum()
- // Reset the checksum field to 0 to can calculate the proper
- // checksum. We'll have to reset this before we hand the packet
- // off.
- h.SetChecksum(0)
- gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
- if gotChecksum != wantChecksum {
- // It's possible that a raw socket expects to receive this.
- h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
- received.Invalid.Increment()
+ sent := stats.ICMP.V4PacketsSent
+ if !r.Stack().AllowICMPMessage() {
+ sent.RateLimited.Increment()
return
}
+ // DeliverTransportPacket will take ownership of pkt so don't use it beyond
+ // this point. Make a deep copy of the data before pkt gets sent as we will
+ // be modifying fields.
+ //
+ // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
+ // waiting endpoints. Consider moving responsibility for doing the copy to
+ // DeliverTransportPacket so that is is only done when needed.
+ replyData := pkt.Data.ToOwnedView()
+
// It's possible that a raw socket expects to receive this.
- h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{
- Data: pkt.Data.Clone(nil),
- NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...),
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ pkt = nil
+ // Take the base of the incoming request IP header but replace the options.
+ replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions))
+ replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...))
+ replyIPHdr.SetHeaderLength(replyHeaderLength)
+
+ // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
+ // source address MUST be one of its own IP addresses (but not a broadcast
+ // or multicast address).
+ localAddr := r.LocalAddress
+ if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) {
+ localAddr = ""
+ }
+
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ // If we cannot find a route to the destination, silently drop the packet.
+ return
+ }
+ defer r.Release()
+
+ // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the
+ // header information, we may have to change this code to handle the
+ // ICMP header no longer being in the data buffer.
+
+ // Because IP and ICMP are so closely intertwined, we need to handcraft our
+ // IP header to be able to follow RFC 792. The wording on page 13 is as
+ // follows:
+ // IP Fields:
+ // Addresses
+ // The address of the source in an echo message will be the
+ // destination of the echo reply message. To form an echo reply
+ // message, the source and destination addresses are simply reversed,
+ // the type code changed to 0, and the checksum recomputed.
+ //
+ // This was interpreted by early implementors to mean that all options must
+ // be copied from the echo request IP header to the echo reply IP header
+ // and this behaviour is still relied upon by some applications.
+ //
+ // Create a copy of the IP header we received, options and all, and change
+ // The fields we need to alter.
+ //
+ // We need to produce the entire packet in the data segment in order to
+ // use WriteHeaderIncludedPacket(). WriteHeaderIncludedPacket sets the
+ // total length and the header checksum so we don't need to set those here.
+ replyIPHdr.SetSourceAddress(r.LocalAddress)
+ replyIPHdr.SetDestinationAddress(r.RemoteAddress)
+ replyIPHdr.SetTTL(r.DefaultTTL())
+
+ replyICMPHdr := header.ICMPv4(replyData)
+ replyICMPHdr.SetType(header.ICMPv4EchoReply)
+ replyICMPHdr.SetChecksum(0)
+ replyICMPHdr.SetChecksum(^header.Checksum(replyData, 0))
+
+ replyVV := buffer.View(replyIPHdr).ToVectorisedView()
+ replyVV.AppendView(replyData)
+ replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: replyVV,
})
+ replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
- vv := pkt.Data.Clone(nil)
- vv.TrimFront(header.ICMPv4MinimumSize)
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- copy(pkt, h)
- pkt.SetType(header.ICMPv4EchoReply)
- pkt.SetChecksum(0)
- pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
- sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: vv,
- TransportHeader: buffer.View(pkt),
- }); err != nil {
+ if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -122,12 +225,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
+ case header.ICMPv4HostUnreachable:
+ e.handleControl(stack.ControlNoRoute, 0, pkt)
+
case header.ICMPv4PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
case header.ICMPv4FragmentationNeeded:
- mtu := uint32(h.MTU())
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
+ networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize)
+ if err != nil {
+ networkMTU = 0
+ }
+ e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt)
}
case header.ICMPv4SrcQuench:
@@ -158,3 +267,217 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
received.Invalid.Increment()
}
}
+
+// ======= ICMP Error packet generation =========
+
+// icmpReason is a marker interface for IPv4 specific ICMP errors.
+type icmpReason interface {
+ isICMPReason()
+}
+
+// icmpReasonPortUnreachable is an error where the transport protocol has no
+// listener and no alternative means to inform the sender.
+type icmpReasonPortUnreachable struct{}
+
+func (*icmpReasonPortUnreachable) isICMPReason() {}
+
+// icmpReasonProtoUnreachable is an error where the transport protocol is
+// not supported.
+type icmpReasonProtoUnreachable struct{}
+
+func (*icmpReasonProtoUnreachable) isICMPReason() {}
+
+// icmpReasonReassemblyTimeout is an error where insufficient fragments are
+// received to complete reassembly of a packet within a configured time after
+// the reception of the first-arriving fragment of that packet.
+type icmpReasonReassemblyTimeout struct{}
+
+func (*icmpReasonReassemblyTimeout) isICMPReason() {}
+
+// icmpReasonParamProblem is an error to use to request a Parameter Problem
+// message to be sent.
+type icmpReasonParamProblem struct {
+ pointer byte
+}
+
+func (*icmpReasonParamProblem) isICMPReason() {}
+
+// returnError takes an error descriptor and generates the appropriate ICMP
+// error packet for IPv4 and sends it back to the remote device that sent
+// the problematic packet. It incorporates as much of that packet as
+// possible as well as any error metadata as is available. returnError
+// expects pkt to hold a valid IPv4 packet as per the wire format.
+func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ // We check we are responding only when we are allowed to.
+ // See RFC 1812 section 4.3.2.7 (shown below).
+ //
+ // =========
+ // 4.3.2.7 When Not to Send ICMP Errors
+ //
+ // An ICMP error message MUST NOT be sent as the result of receiving:
+ //
+ // o An ICMP error message, or
+ //
+ // o A packet which fails the IP header validation tests described in
+ // Section [5.2.2] (except where that section specifically permits
+ // the sending of an ICMP error message), or
+ //
+ // o A packet destined to an IP broadcast or IP multicast address, or
+ //
+ // o A packet sent as a Link Layer broadcast or multicast, or
+ //
+ // o Any fragment of a datagram other then the first fragment (i.e., a
+ // packet for which the fragment offset in the IP header is nonzero).
+ //
+ // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in
+ // response to a non-initial fragment, but it currently can not happen.
+
+ if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any {
+ return nil
+ }
+
+ // Even if we were able to receive a packet from some remote, we may not have
+ // a route to it - the remote may be blocked via routing rules. We must always
+ // consult our routing table and find a route to the remote before sending any
+ // packet.
+ route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ // From this point on, the incoming route should no longer be used; route
+ // must be used to send the ICMP error.
+ r = nil
+
+ sent := p.stack.Stats().ICMP.V4PacketsSent
+ if !p.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
+ networkHeader := pkt.NetworkHeader().View()
+ transportHeader := pkt.TransportHeader().View()
+
+ // Don't respond to icmp error packets.
+ if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) {
+ // TODO(gvisor.dev/issue/3810):
+ // Unfortunately the current stack pretty much always has ICMPv4 headers
+ // in the Data section of the packet but there is no guarantee that is the
+ // case. If this is the case grab the header to make it like all other
+ // packet types. When this is cleaned up the Consume should be removed.
+ if transportHeader.IsEmpty() {
+ var ok bool
+ transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize)
+ if !ok {
+ return nil
+ }
+ } else if transportHeader.Size() < header.ICMPv4MinimumSize {
+ return nil
+ }
+ // We need to decide to explicitly name the packets we can respond to or
+ // the ones we can not respond to. The decision is somewhat arbitrary and
+ // if problems arise this could be reversed. It was judged less of a breach
+ // of protocol to not respond to unknown non-error packets than to respond
+ // to unknown error packets so we take the first approach.
+ switch header.ICMPv4(transportHeader).Type() {
+ case
+ header.ICMPv4EchoReply,
+ header.ICMPv4Echo,
+ header.ICMPv4Timestamp,
+ header.ICMPv4TimestampReply,
+ header.ICMPv4InfoRequest,
+ header.ICMPv4InfoReply:
+ default:
+ // Assume any type we don't know about may be an error type.
+ return nil
+ }
+ }
+
+ // Now work out how much of the triggering packet we should return.
+ // As per RFC 1812 Section 4.3.2.3
+ //
+ // ICMP datagram SHOULD contain as much of the original
+ // datagram as possible without the length of the ICMP
+ // datagram exceeding 576 bytes.
+ //
+ // NOTE: The above RFC referenced is different from the original
+ // recommendation in RFC 1122 and RFC 792 where it mentioned that at
+ // least 8 bytes of the payload must be included. Today linux and other
+ // systems implement the RFC 1812 definition and not the original
+ // requirement. We treat 8 bytes as the minimum but will try send more.
+ mtu := int(route.MTU())
+ if mtu > header.IPv4MinimumProcessableDatagramSize {
+ mtu = header.IPv4MinimumProcessableDatagramSize
+ }
+ headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ available := int(mtu) - headerLen
+
+ if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize {
+ return nil
+ }
+
+ payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+
+ // The buffers used by pkt may be used elsewhere in the system.
+ // For example, an AF_RAW or AF_PACKET socket may use what the transport
+ // protocol considers an unreachable destination. Thus we deep copy pkt to
+ // prevent multiple ownership and SR errors. The new copy is a vectorized
+ // view with the entire incoming IP packet reassembled and truncated as
+ // required. This is now the payload of the new ICMP packet and no longer
+ // considered a packet in its own right.
+ newHeader := append(buffer.View(nil), networkHeader...)
+ newHeader = append(newHeader, transportHeader...)
+ payload := newHeader.ToVectorisedView()
+ payload.AppendView(pkt.Data.ToView())
+ payload.CapLength(payloadLen)
+
+ icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: headerLen,
+ Data: payload,
+ })
+
+ icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
+
+ icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ var counter *tcpip.StatCounter
+ switch reason := reason.(type) {
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ counter = sent.DstUnreachable
+ case *icmpReasonProtoUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
+ counter = sent.DstUnreachable
+ case *icmpReasonReassemblyTimeout:
+ icmpHdr.SetType(header.ICMPv4TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
+ counter = sent.TimeExceeded
+ case *icmpReasonParamProblem:
+ icmpHdr.SetType(header.ICMPv4ParamProblem)
+ icmpHdr.SetCode(header.ICMPv4UnusedCode)
+ icmpHdr.SetPointer(reason.pointer)
+ counter = sent.ParamProblem
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
+
+ if err := route.WritePacket(
+ nil, /* gso */
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv4ProtocolNumber,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ },
+ icmpPkt,
+ ); err != nil {
+ sent.Dropped.Increment()
+ return err
+ }
+ counter.Increment()
+ return nil
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 64046cbbf..4592984a5 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -12,26 +12,38 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ipv4 contains the implementation of the ipv4 network protocol. To use
-// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv4.NewProtocol() as one of the network
-// protocols when calling stack.New(). Then endpoints can be created by passing
-// ipv4.ProtocolNumber as the network protocol number when calling
-// Stack.NewEndpoint().
+// Package ipv4 contains the implementation of the ipv4 network protocol.
package ipv4
import (
+ "errors"
+ "fmt"
+ "math"
"sync/atomic"
+ "time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
+ // ReassembleTimeout is the time a packet stays in the reassembly
+ // system before being evicted.
+ // As per RFC 791 section 3.2:
+ // The current recommendation for the initial timer setting is 15 seconds.
+ // This may be changed as experience with this protocol accumulates.
+ //
+ // Considering that it is an old recommendation, we use the same reassembly
+ // timeout that linux defines, which is 30 seconds:
+ // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138
+ ReassembleTimeout = 30 * time.Second
+
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -44,78 +56,141 @@ const (
// buckets is the number of identifier buckets.
buckets = 2048
+
+ // The size of a fragment block, in bytes, as per RFC 791 section 3.1,
+ // page 14.
+ fragmentblockSize = 8
)
+var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
+
+var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- nicID tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
- linkEP stack.LinkEndpoint
- dispatcher stack.TransportDispatcher
- fragmentation *fragmentation.Fragmentation
- protocol *protocol
- stack *stack.Stack
+ nic stack.NetworkInterface
+ dispatcher stack.TransportDispatcher
+ protocol *protocol
+
+ // enabled is set to 1 when the enpoint is enabled and 0 when it is
+ // disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ mu struct {
+ sync.RWMutex
+
+ addressableEndpointState stack.AddressableEndpointState
+ }
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
- nicID: nicID,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
- linkEP: linkEP,
- dispatcher: dispatcher,
- fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
- protocol: p,
- stack: st,
+ nic: nic,
+ dispatcher: dispatcher,
+ protocol: p,
+ }
+ e.mu.addressableEndpointState.Init(e)
+ return e
+}
+
+// Enable implements stack.NetworkEndpoint.
+func (e *endpoint) Enable() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // If the NIC is not enabled, the endpoint can't do anything meaningful so
+ // don't enable the endpoint.
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ // If the endpoint is already enabled, there is nothing for it to do.
+ if !e.setEnabled(true) {
+ return nil
}
- return e, nil
+ // Create an endpoint to receive broadcast packets on this interface.
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ if err != nil {
+ return err
+ }
+ // We have no need for the address endpoint.
+ ep.DecRef()
+
+ // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
+ // multicast group. Note, the IANA calls the all-hosts multicast group the
+ // all-systems multicast group.
+ _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems)
+ return err
}
-// DefaultTTL is the default time-to-live value for this endpoint.
-func (e *endpoint) DefaultTTL() uint8 {
- return e.protocol.DefaultTTL()
+// Enabled implements stack.NetworkEndpoint.
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
}
-// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
-// the network layer max header length.
-func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
}
-// Capabilities implements stack.NetworkEndpoint.Capabilities.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
+// setEnabled sets the enabled status for the endpoint.
+//
+// Returns true if the enabled status was updated.
+func (e *endpoint) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&e.enabled, 1) == 0
+ }
+ return atomic.SwapUint32(&e.enabled, 0) == 1
+}
+
+// Disable implements stack.NetworkEndpoint.
+func (e *endpoint) Disable() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.disableLocked()
}
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
+func (e *endpoint) disableLocked() {
+ if !e.setEnabled(false) {
+ return
+ }
+
+ // The endpoint may have already left the multicast group.
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
+ }
+
+ // The address may have already been removed.
+ if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
+ }
}
-// ID returns the ipv4 endpoint ID.
-func (e *endpoint) ID() *stack.NetworkEndpointID {
- return &e.id
+// DefaultTTL is the default time-to-live value for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return e.protocol.DefaultTTL()
}
-// PrefixLen returns the ipv4 endpoint subnet prefix length in bits.
-func (e *endpoint) PrefixLen() int {
- return e.prefixLen
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize)
+ if err != nil {
+ return 0
+ }
+ return networkMTU
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
-}
-
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
- }
- return 0
+ return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
@@ -123,113 +198,13 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
-// write. It assumes that the IP header is entirely in pkt.Header but does not
-// assume that only the IP header is in pkt.Header. It assumes that the input
-// packet's stated length matches the length of the header+payload. mtu
-// includes the IP header and options. This does not support the DontFragment
-// IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error {
- // This packet is too big, it needs to be fragmented.
- ip := header.IPv4(pkt.Header.View())
- flags := ip.Flags()
-
- // Update mtu to take into account the header, which will exist in all
- // fragments anyway.
- innerMTU := mtu - int(ip.HeaderLength())
-
- // Round the MTU down to align to 8 bytes. Then calculate the number of
- // fragments. Calculate fragment sizes as in RFC791.
- innerMTU &^= 7
- n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
-
- outerMTU := innerMTU + int(ip.HeaderLength())
- offset := ip.FragmentOffset()
- originalAvailableLength := pkt.Header.AvailableLength()
- for i := 0; i < n; i++ {
- // Where possible, the first fragment that is sent has the same
- // pkt.Header.UsedLength() as the input packet. The link-layer
- // endpoint may depend on this for looking at, eg, L4 headers.
- h := ip
- if i > 0 {
- pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
- h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength())))
- copy(h, ip[:ip.HeaderLength()])
- }
- if i != n-1 {
- h.SetTotalLength(uint16(outerMTU))
- h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
- } else {
- h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size()))
- h.SetFlagsFragmentOffset(flags, offset)
- }
- h.SetChecksum(0)
- h.SetChecksum(^h.CalculateChecksum())
- offset += uint16(innerMTU)
- if i > 0 {
- newPayload := pkt.Data.Clone(nil)
- newPayload.CapLength(innerMTU)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
- Header: pkt.Header,
- Data: newPayload,
- NetworkHeader: buffer.View(h),
- }); err != nil {
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- pkt.Data.TrimFront(newPayload.Size())
- continue
- }
- // Special handling for the first fragment because it comes
- // from the header.
- if outerMTU >= pkt.Header.UsedLength() {
- // This fragment can fit all of pkt.Header and possibly
- // some of pkt.Data, too.
- newPayload := pkt.Data.Clone(nil)
- newPayloadLength := outerMTU - pkt.Header.UsedLength()
- newPayload.CapLength(newPayloadLength)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
- Header: pkt.Header,
- Data: newPayload,
- NetworkHeader: buffer.View(h),
- }); err != nil {
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- pkt.Data.TrimFront(newPayloadLength)
- } else {
- // The fragment is too small to fit all of pkt.Header.
- startOfHdr := pkt.Header
- startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU)
- emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{
- Header: startOfHdr,
- Data: emptyVV,
- NetworkHeader: buffer.View(h),
- }); err != nil {
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- // Add the unused bytes of pkt.Header into the pkt.Data
- // that remains to be sent.
- restOfHdr := pkt.Header.View()[outerMTU:]
- tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
- tmp.Append(pkt.Data)
- pkt.Data = tmp
- }
- }
- return nil
-}
-
-func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 {
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- length := uint16(hdr.UsedLength() + payloadSize)
- id := uint32(0)
- if length > header.IPv4MaximumHeaderSize+8 {
- // Packets of 68 bytes or less are required by RFC 791 to not be
- // fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
- }
+func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+ ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
+ length := uint16(pkt.Size())
+ // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
+ // datagrams. Since the DF bit is never being set here, all datagrams
+ // are non-atomic and need an ID.
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: length,
@@ -241,64 +216,96 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
- return ip
+ pkt.NetworkProtocolNumber = ProtocolNumber
+}
+
+// handleFragments fragments pkt and calls the handler function on each
+// fragment. It returns the number of fragments handled and the number of
+// fragments left to be processed. The IP header must already be present in the
+// original packet.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ // Round the MTU down to align to 8 bytes.
+ fragmentPayloadSize := networkMTU &^ 7
+ networkHeader := header.IPv4(pkt.NetworkHeader().View())
+ pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadSize, pkt.AvailableHeaderBytes()+len(networkHeader))
+
+ var n int
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader)
+ if err := handler(fragPkt); err != nil {
+ return n, pf.RemainingFragmentCount() + 1, err
+ }
+ n++
+ if !more {
+ return n, pf.RemainingFragmentCount(), nil
+ }
+ }
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
- ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
- pkt.NetworkHeader = buffer.View(ip)
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.addIPHeader(r, pkt, params)
+ return e.writePacket(r, gso, pkt)
+}
- nicName := e.stack.FindNICNameFromID(e.NICID())
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.stack.IPTables()
- if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok {
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
+ if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
+ // If the packet is manipulated as per NAT Output rules, handle packet
+ // based on destination address and do not send the packet to link
+ // layer.
+ //
+ // TODO(gvisor.dev/issue/170): We should do this for every
+ // packet, rather than only NATted packets, but removing this check
+ // short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
- // If the packet is manipulated as per NAT Ouput rules, handle packet
- // based on destination address and do not send the packet to link layer.
- netHeader := header.IPv4(pkt.NetworkHeader)
- ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
- src := netHeader.SourceAddress()
- dst := netHeader.DestinationAddress()
- route := r.ReverseRoute(src, dst)
-
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
- packet := stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)}
- ep.HandlePacket(&route, packet)
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ ep.HandlePacket(&route, pkt)
return nil
}
}
if r.Loop&stack.PacketLoop != 0 {
- // The inbound path expects the network header to still be in
- // the PacketBuffer's Data field.
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
-
- e.HandlePacket(&loopedR, stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
- })
-
+ e.HandlePacket(&loopedR, pkt)
loopedR.Release()
}
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
- return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
+
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
+
+ if packetMustBeFragmented(pkt, networkMTU, gso) {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
+ // fragment one by one using WritePacket() (current strategy) or if we
+ // want to create a PacketBufferList from the fragments and feed it to
+ // WritePackets(). It'll be faster but cost more memory.
+ return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
+ })
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ return err
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
r.Stats().IP.PacketsSent.Increment()
@@ -314,26 +321,49 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return pkts.Len(), nil
}
- for pkt := pkts.Front(); pkt != nil; {
- ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
- pkt.NetworkHeader = buffer.View(ip)
- pkt = pkt.Next()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.addIPHeader(r, pkt, params)
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+
+ if packetMustBeFragmented(pkt, networkMTU, gso) {
+ // Keep track of the packet that is about to be fragmented so it can be
+ // removed once the fragmentation is done.
+ originalPkt := pkt
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // Modify the packet list in place with the new fragments.
+ pkts.InsertAfter(pkt, fragPkt)
+ pkt = fragPkt
+ return nil
+ }); err != nil {
+ panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", networkMTU, err))
+ }
+ // Remove the packet that was just fragmented and process the rest.
+ pkts.Remove(originalPkt)
+ }
}
- nicName := e.stack.FindNICNameFromID(e.NICID())
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
return n, err
}
+ r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
- // Slow Path as we are dropping some packets in the batch degrade to
+ // Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
@@ -341,120 +371,139 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
continue
}
if _, ok := natPkts[pkt]; ok {
- netHeader := header.IPv4(pkt.NetworkHeader)
- ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
- if err == nil {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
-
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
- packet := stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)}
- ep.HandlePacket(&route, packet)
+ ep.HandlePacket(&route, pkt)
n++
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
+ // Dropped packets aren't errors, so include them in
+ // the return value.
+ return n + len(dropped), err
}
n++
}
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, nil
+ // Dropped packets aren't errors, so include them in the return value.
+ return n + len(dropped), nil
}
-// WriteHeaderIncludedPacket writes a packet already containing a network
-// header through the given route.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
+// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
- return tcpip.ErrInvalidOptionValue
+ return tcpip.ErrMalformedHeader
}
ip := header.IPv4(h)
- if !ip.IsValid(pkt.Data.Size()) {
- return tcpip.ErrInvalidOptionValue
- }
// Always set the total length.
- ip.SetTotalLength(uint16(pkt.Data.Size()))
+ pktSize := pkt.Data.Size()
+ ip.SetTotalLength(uint16(pktSize))
// Set the source address when zero.
- if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
+ if ip.SourceAddress() == header.IPv4Any {
ip.SetSourceAddress(r.LocalAddress)
}
- // Set the destination. If the packet already included a destination,
- // it will be part of the route.
+ // Set the destination. If the packet already included a destination, it will
+ // be part of the route anyways.
ip.SetDestinationAddress(r.RemoteAddress)
// Set the packet ID when zero.
if ip.ID() == 0 {
- id := uint32(0)
- if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 {
- // Packets of 68 bytes or less are required by RFC 791 to not be
- // fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
+ // RFC 6864 section 4.3 mandates uniqueness of ID values for
+ // non-atomic datagrams, so assign an ID to all such datagrams
+ // according to the definition given in RFC 6864 section 4.
+ if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
+ ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
- ip.SetID(uint16(id))
}
// Always set the checksum.
ip.SetChecksum(0)
ip.SetChecksum(^ip.CalculateChecksum())
- if r.Loop&stack.PacketLoop != 0 {
- e.HandlePacket(r, pkt.Clone())
- }
- if r.Loop&stack.PacketOut == 0 {
- return nil
+ // Populate the packet buffer's network header and don't allow an invalid
+ // packet to be sent.
+ //
+ // Note that parsing only makes sure that the packet is well formed as per the
+ // wire format. We also want to check if the header's fields are valid before
+ // sending the packet.
+ if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) {
+ return tcpip.ErrMalformedHeader
}
- r.Stats().IP.PacketsSent.Increment()
-
- ip = ip[:ip.HeaderLength()]
- pkt.Header = buffer.NewPrependableFromView(buffer.View(ip))
- pkt.Data.TrimFront(int(ip.HeaderLength()))
- return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ return e.writePacket(r, nil /* gso */, pkt)
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ if !e.isEnabled() {
+ return
+ }
+
+ h := header.IPv4(pkt.NetworkHeader().View())
+ if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
- h := header.IPv4(headerView)
- if !h.IsValid(pkt.Data.Size()) {
+
+ // There has been some confusion regarding verifying checksums. We need
+ // just look for negative 0 (0xffff) as the checksum, as it's not possible to
+ // get positive 0 (0) for the checksum. Some bad implementations could get it
+ // when doing entry replacement in the early days of the Internet,
+ // however the lore that one needs to check for both persists.
+ //
+ // RFC 1624 section 1 describes the source of this confusion as:
+ // [the partial recalculation method described in RFC 1071] computes a
+ // result for certain cases that differs from the one obtained from
+ // scratch (one's complement of one's complement sum of the original
+ // fields).
+ //
+ // However RFC 1624 section 5 clarifies that if using the verification method
+ // "recommended by RFC 1071, it does not matter if an intermediate system
+ // generated a -0 instead of +0".
+ //
+ // RFC1071 page 1 specifies the verification method as:
+ // (3) To check a checksum, the 1's complement sum is computed over the
+ // same set of octets, including the checksum field. If the result
+ // is all 1 bits (-0 in 1's complement arithmetic), the check
+ // succeeds.
+ if h.CalculateChecksum() != 0xffff {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
- pkt.NetworkHeader = headerView[:h.HeaderLength()]
- hlen := int(h.HeaderLength())
- tlen := int(h.TotalLength())
- pkt.Data.TrimFront(hlen)
- pkt.Data.CapLength(tlen - hlen)
+ // As per RFC 1122 section 3.2.1.3:
+ // When a host sends any datagram, the IP source address MUST
+ // be one of its own IP addresses (but not a broadcast or
+ // multicast address).
+ if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) {
+ r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- ipt := e.stack.IPTables()
- if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok {
+ ipt := e.protocol.stack.IPTables()
+ if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesInputDropped.Increment()
return
}
- more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
- if more || h.FragmentOffset() != 0 {
- if pkt.Data.Size() == 0 {
+ if h.More() || h.FragmentOffset() != 0 {
+ if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -462,18 +511,60 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
return
}
// The packet is a fragment, let's try to reassemble it.
- last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1
- // Drop the packet if the fragmentOffset is incorrect. i.e the
- // combination of fragmentOffset and pkt.Data.size() causes a
- // wrap around resulting in last being less than the offset.
- if last < h.FragmentOffset() {
+ start := h.FragmentOffset()
+ // Drop the fragment if the size of the reassembled payload would exceed the
+ // maximum payload size.
+ //
+ // Note that this addition doesn't overflow even on 32bit architecture
+ // because pkt.Data.Size() should not exceed 65535 (the max IP datagram
+ // size). Otherwise the packet would've been rejected as invalid before
+ // reaching here.
+ if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
+
+ // Set up a callback in case we need to send a Time Exceeded Message, as per
+ // RFC 792:
+ //
+ // If a host reassembling a fragmented datagram cannot complete the
+ // reassembly due to missing fragments within its time limit it discards
+ // the datagram, and it may send a time exceeded message.
+ //
+ // If fragment zero is not available then no time exceeded need be sent at
+ // all.
+ var releaseCB func(bool)
+ if start == 0 {
+ pkt := pkt.Clone()
+ r := r.Clone()
+ releaseCB = func(timedOut bool) {
+ if timedOut {
+ _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt)
+ }
+ r.Release()
+ }
+ }
+
var ready bool
var err error
- pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data)
+ proto := h.Protocol()
+ pkt.Data, _, ready, err = e.protocol.fragmentation.Process(
+ // As per RFC 791 section 2.3, the identification value is unique
+ // for a source-destination pair and protocol.
+ fragmentation.FragmentID{
+ Source: h.SourceAddress(),
+ Destination: h.DestinationAddress(),
+ ID: uint32(h.ID()),
+ Protocol: proto,
+ },
+ start,
+ start+uint16(pkt.Data.Size())-1,
+ h.More(),
+ proto,
+ pkt.Data,
+ releaseCB,
+ )
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
@@ -482,28 +573,193 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
if !ready {
return
}
+
+ // The reassembler doesn't take care of fixing up the header, so we need
+ // to do it here.
+ h.SetTotalLength(uint16(pkt.Data.Size() + len((h))))
+ h.SetFlagsFragmentOffset(0, 0)
}
+ r.Stats().IP.PacketsDelivered.Increment()
+
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
- headerView.CapLength(hlen)
+ // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport
+ // headers, the setting of the transport number here should be
+ // unnecessary and removed.
+ pkt.TransportProtocolNumber = p
e.handleICMP(r, pkt)
return
}
- r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, pkt)
+ if len(h.Options()) != 0 {
+ // TODO(gvisor.dev/issue/4586):
+ // When we add forwarding support we should use the verified options
+ // rather than just throwing them away.
+ aux, _, err := processIPOptions(r, h.Options(), &optionUsageReceive{})
+ if err != nil {
+ switch {
+ case
+ errors.Is(err, header.ErrIPv4OptDuplicate),
+ errors.Is(err, errIPv4RecordRouteOptInvalidPointer),
+ errors.Is(err, errIPv4RecordRouteOptInvalidLength),
+ errors.Is(err, errIPv4TimestampOptInvalidLength),
+ errors.Is(err, errIPv4TimestampOptInvalidPointer),
+ errors.Is(err, errIPv4TimestampOptOverflow):
+ _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt)
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ }
+ return
+ }
+ }
+
+ switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ case stack.TransportPacketHandled:
+ case stack.TransportPacketDestinationPortUnreachable:
+ // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
+ // Unreachable messages with code:
+ // 3 (Port Unreachable), when the designated transport protocol
+ // (e.g., UDP) is unable to demultiplex the datagram but has no
+ // protocol mechanism to inform the sender.
+ _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC: 1122 Section 3.2.2.1
+ // A host SHOULD generate Destination Unreachable messages with code:
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported
+ _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt)
+ default:
+ panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
+ }
}
// Close cleans up resources associated with the endpoint.
-func (e *endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.disableLocked()
+ e.mu.addressableEndpointState.Cleanup()
+}
+
+// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+}
+
+// RemovePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
+}
+
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
+// AcquireAssignedAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ loopback := e.nic.IsLoopback()
+ addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
+ subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ // IPv4 has a notion of a subnet broadcast address and considers the
+ // loopback interface bound to an address's whole subnet (on linux).
+ return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
+ })
+ if addressEndpoint != nil {
+ return addressEndpoint
+ }
+
+ if !allowTemp {
+ return nil
+ }
+
+ addr := localAddr.WithPrefix()
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB)
+ if err != nil {
+ // AddAddress only returns an error if the address is already assigned,
+ // but we just checked above if the address exists so we expect no error.
+ panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err))
+ }
+ return addressEndpoint
+}
+
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
+}
+
+// PrimaryAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PrimaryAddresses()
+}
+
+// PermanentAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PermanentAddresses()
+}
+
+// JoinGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ if !header.IsV4MulticastAddress(addr) {
+ return false, tcpip.ErrBadAddress
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.JoinGroup(addr)
+}
+
+// LeaveGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.LeaveGroup(addr)
+}
+
+// IsInGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.IsInGroup(addr)
+}
+
+var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
+var _ stack.NetworkProtocol = (*protocol)(nil)
type protocol struct {
- ids []uint32
- hashIV uint32
+ stack *stack.Stack
// defaultTTL is the current default TTL for the protocol. Only the
- // uint8 portion of it is meaningful and it must be accessed
- // atomically.
+ // uint8 portion of it is meaningful.
+ //
+ // Must be accessed using atomic operations.
defaultTTL uint32
+
+ // forwarding is set to 1 when the protocol has forwarding enabled and 0
+ // when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
+ ids []uint32
+ hashIV uint32
+
+ fragmentation *fragmentation.Fragmentation
}
// Number returns the ipv4 protocol number.
@@ -528,10 +784,10 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// SetOption implements NetworkProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case tcpip.DefaultTTLOption:
- p.SetDefaultTTL(uint8(v))
+ case *tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(*v))
return nil
default:
return tcpip.ErrUnknownProtocolOption
@@ -539,7 +795,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
// Option implements NetworkProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(p.DefaultTTL())
@@ -565,28 +821,80 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// calculateMTU calculates the network-layer payload MTU based on the link-layer
-// payload mtu.
-func calculateMTU(mtu uint32) uint32 {
- if mtu > MaxTotalSize {
- mtu = MaxTotalSize
+// Parse implements stack.NetworkProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ if ok := parse.IPv4(pkt); !ok {
+ return 0, false, false
+ }
+
+ ipHdr := header.IPv4(pkt.NetworkHeader().View())
+ return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
+}
+
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) Forwarding() bool {
+ return uint8(atomic.LoadUint32(&p.forwarding)) == 1
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) SetForwarding(v bool) {
+ if v {
+ atomic.StoreUint32(&p.forwarding, 1)
+ } else {
+ atomic.StoreUint32(&p.forwarding, 0)
+ }
+}
+
+// calculateNetworkMTU calculates the network-layer payload MTU based on the
+// link-layer payload mtu.
+func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) {
+ if linkMTU < header.IPv4MinimumMTU {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in
+ // length:
+ // The maximal internet header is 60 octets, and a typical internet header
+ // is 20 octets, allowing a margin for headers of higher level protocols.
+ if networkHeaderSize > header.IPv4MaximumHeaderSize {
+ return 0, tcpip.ErrMalformedHeader
+ }
+
+ networkMTU := linkMTU
+ if networkMTU > MaxTotalSize {
+ networkMTU = MaxTotalSize
}
- return mtu - header.IPv4MinimumSize
+
+ return networkMTU - uint32(networkHeaderSize), nil
+}
+
+func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
+ payload := pkt.TransportHeader().View().Size() + pkt.Data.Size()
+ return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU
+}
+
+// addressToUint32 translates an IPv4 address into its little endian uint32
+// representation.
+//
+// This function does the same thing as binary.LittleEndian.Uint32 but operates
+// on a tcpip.Address (a string) without the need to convert it to a byte slice,
+// which would cause an allocation.
+func addressToUint32(addr tcpip.Address) uint32 {
+ _ = addr[3] // bounds check hint to compiler
+ return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24
}
// hashRoute calculates a hash value for the given route. It uses the source &
-// destination address, the transport protocol number, and a random initial
-// value (generated once on initialization) to generate the hash.
+// destination address, the transport protocol number and a 32-bit number to
+// generate the hash.
func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
- t := r.LocalAddress
- a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
- t = r.RemoteAddress
- b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ a := addressToUint32(r.LocalAddress)
+ b := addressToUint32(r.RemoteAddress)
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
// NewProtocol returns an IPv4 network protocol.
-func NewProtocol() stack.NetworkProtocol {
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
@@ -596,5 +904,353 @@ func NewProtocol() stack.NetworkProtocol {
}
hashIV := r[buckets]
- return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL}
+ return &protocol{
+ stack: s,
+ ids: ids,
+ hashIV: hashIV,
+ defaultTTL: DefaultTTL,
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
+ }
+}
+
+func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
+
+ originalIPHeaderLength := len(originalIPHeader)
+ nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength))
+
+ if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) {
+ panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength))
+ }
+
+ flags := originalIPHeader.Flags()
+ if more {
+ flags |= header.IPv4FlagMoreFragments
+ }
+ nextFragIPHeader.SetFlagsFragmentOffset(flags, uint16(offset))
+ nextFragIPHeader.SetTotalLength(uint16(nextFragIPHeader.HeaderLength()) + uint16(copied))
+ nextFragIPHeader.SetChecksum(0)
+ nextFragIPHeader.SetChecksum(^nextFragIPHeader.CalculateChecksum())
+
+ return fragPkt, more
+}
+
+// optionAction describes possible actions that may be taken on an option
+// while processing it.
+type optionAction uint8
+
+const (
+ // optionRemove says that the option should not be in the output option set.
+ optionRemove optionAction = iota
+
+ // optionProcess says that the option should be fully processed.
+ optionProcess
+
+ // optionVerify says the option should be checked and passed unchanged.
+ optionVerify
+
+ // optionPass says to pass the output set without checking.
+ optionPass
+)
+
+// optionActions list what to do for each option in a given scenario.
+type optionActions struct {
+ // timestamp controls what to do with a Timestamp option.
+ timestamp optionAction
+
+ // recordroute controls what to do with a Record Route option.
+ recordRoute optionAction
+
+ // unknown controls what to do with an unknown option.
+ unknown optionAction
+}
+
+// optionsUsage specifies the ways options may be operated upon for a given
+// scenario during packet processing.
+type optionsUsage interface {
+ actions() optionActions
+}
+
+// optionUsageReceive implements optionsUsage for received packets.
+type optionUsageReceive struct{}
+
+// actions implements optionsUsage.
+func (*optionUsageReceive) actions() optionActions {
+ return optionActions{
+ timestamp: optionVerify,
+ recordRoute: optionVerify,
+ unknown: optionPass,
+ }
+}
+
+// TODO(gvisor.dev/issue/4586): Add an entry here for forwarding when it
+// is enabled (Process, Process, Pass) and for fragmenting (Process, Process,
+// Pass for frag1, but Remove,Remove,Remove for all other frags).
+
+// optionUsageEcho implements optionsUsage for echo packet processing.
+type optionUsageEcho struct{}
+
+// actions implements optionsUsage.
+func (*optionUsageEcho) actions() optionActions {
+ return optionActions{
+ timestamp: optionProcess,
+ recordRoute: optionProcess,
+ unknown: optionRemove,
+ }
+}
+
+var (
+ errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length")
+ errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer")
+ errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp")
+ errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags")
+)
+
+// handleTimestamp does any required processing on a Timestamp option
+// in place.
+func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) {
+ flags := tsOpt.Flags()
+ var entrySize uint8
+ switch flags {
+ case header.IPv4OptionTimestampOnlyFlag:
+ entrySize = header.IPv4OptionTimestampSize
+ case
+ header.IPv4OptionTimestampWithIPFlag,
+ header.IPv4OptionTimestampWithPredefinedIPFlag:
+ entrySize = header.IPv4OptionTimestampWithAddrSize
+ default:
+ return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags
+ }
+
+ pointer := tsOpt.Pointer()
+ // To simplify processing below, base further work on the array of timestamps
+ // beyond the header, rather than on the whole option. Also to aid
+ // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based.
+ nextSlot := pointer - (header.IPv4OptionTimestampHdrLength + 1)
+ optLen := tsOpt.Size()
+ dataLength := optLen - header.IPv4OptionTimestampHdrLength
+
+ // In the section below, we verify the pointer, length and overflow counter
+ // fields of the option. The distinction is in which byte you return as being
+ // in error in the ICMP packet. Offsets 1 (length), 2 pointer)
+ // or 3 (overflowed counter).
+ //
+ // The following RFC sections cover this section:
+ //
+ // RFC 791 (page 22):
+ // If there is some room but not enough room for a full timestamp
+ // to be inserted, or the overflow count itself overflows, the
+ // original datagram is considered to be in error and is discarded.
+ // In either case an ICMP parameter problem message may be sent to
+ // the source host [3].
+ //
+ // You can get this situation in two ways. Firstly if the data area is not
+ // a multiple of the entry size or secondly, if the pointer is not at a
+ // multiple of the entry size. The wording of the RFC suggests that
+ // this is not an error until you actually run out of space.
+ if pointer > optLen {
+ // RFC 791 (page 22) says we should switch to using the overflow count.
+ // If the timestamp data area is already full (the pointer exceeds
+ // the length) the datagram is forwarded without inserting the
+ // timestamp, but the overflow count is incremented by one.
+ if flags == header.IPv4OptionTimestampWithPredefinedIPFlag {
+ // By definition we have nothing to do.
+ return 0, nil
+ }
+
+ if tsOpt.IncOverflow() != 0 {
+ return 0, nil
+ }
+ // The overflow count is also full.
+ return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow
+ }
+ if nextSlot+entrySize > dataLength {
+ // The data area isn't full but there isn't room for a new entry.
+ // Either Length or Pointer could be bad.
+ if false {
+ // We must select Pointer for Linux compatibility, even if
+ // only the length is bad.
+ // The Linux code is at (in October 2020)
+ // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L367-L370
+ // if (optptr[2]+3 > optlen) {
+ // pp_ptr = optptr + 2;
+ // goto error;
+ // }
+ // which doesn't distinguish between which of optptr[2] or optlen
+ // is wrong, but just arbitrarily decides on optptr+2.
+ if dataLength%entrySize != 0 {
+ // The Data section size should be a multiple of the expected
+ // timestamp entry size.
+ return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength
+ }
+ // If the size is OK, the pointer must be corrupted.
+ }
+ return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer
+ }
+
+ if usage.actions().timestamp == optionProcess {
+ tsOpt.UpdateTimestamp(localAddress, clock)
+ }
+ return 0, nil
+}
+
+var (
+ errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route")
+ errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route")
+)
+
+// handleRecordRoute checks and processes a Record route option. It is much
+// like the timestamp type 1 option, but without timestamps. The passed in
+// address is stored in the option in the correct spot if possible.
+func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) {
+ optlen := rrOpt.Size()
+
+ if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength {
+ return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
+ }
+
+ nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based.
+
+ // RFC 791 page 21 says
+ // If the route data area is already full (the pointer exceeds the
+ // length) the datagram is forwarded without inserting the address
+ // into the recorded route. If there is some room but not enough
+ // room for a full address to be inserted, the original datagram is
+ // considered to be in error and is discarded. In either case an
+ // ICMP parameter problem message may be sent to the source
+ // host.
+ // The use of the words "In either case" suggests that a 'full' RR option
+ // could generate an ICMP at every hop after it fills up. We chose to not
+ // do this (as do most implementations). It is probable that the inclusion
+ // of these words is a copy/paste error from the timestamp option where
+ // there are two failure reasons given.
+ if nextSlot >= optlen {
+ return 0, nil
+ }
+
+ // The data area isn't full but there isn't room for a new entry.
+ // Either Length or Pointer could be bad. We must select Pointer for Linux
+ // compatibility, even if only the length is bad.
+ if nextSlot+header.IPv4AddressSize > optlen {
+ if false {
+ // This is what we would do if we were not being Linux compatible.
+ // Check for bad pointer or length value. Must be a multiple of 4 after
+ // accounting for the 3 byte header and not within that header.
+ // RFC 791, page 20 says:
+ // The pointer is relative to this option, and the
+ // smallest legal value for the pointer is 4.
+ //
+ // A recorded route is composed of a series of internet addresses.
+ // Each internet address is 32 bits or 4 octets.
+ // Linux skips this test so we must too. See Linux code at:
+ // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L338-L341
+ // if (optptr[2]+3 > optlen) {
+ // pp_ptr = optptr + 2;
+ // goto error;
+ // }
+ if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 {
+ // Length is bad, not on integral number of slots.
+ return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
+ }
+ // If not length, the fault must be with the pointer.
+ }
+ return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer
+ }
+ if usage.actions().recordRoute == optionVerify {
+ return 0, nil
+ }
+ rrOpt.StoreAddress(localAddress)
+ return 0, nil
+}
+
+// processIPOptions parses the IPv4 options and produces a new set of options
+// suitable for use in the next step of packet processing as informed by usage.
+// The original will not be touched.
+//
+// Returns
+// - The location of an error if there was one (or 0 if no error)
+// - If there is an error, information as to what it was was.
+// - The replacement option set.
+func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) {
+
+ opts := header.IPv4Options(orig)
+ optIter := opts.MakeIterator()
+
+ // Each option other than NOP must only appear (RFC 791 section 3.1, at the
+ // definition of every type). Keep track of each of the possible types in
+ // the 8 bit 'type' field.
+ var seenOptions [math.MaxUint8 + 1]bool
+
+ // TODO(gvisor.dev/issue/4586):
+ // This will need tweaking when we start really forwarding packets
+ // as we may need to get two addresses, for rx and tx interfaces.
+ // We will also have to take usage into account.
+ prefixedAddress, err := r.Stack().GetMainNICAddress(r.NICID(), ProtocolNumber)
+ localAddress := prefixedAddress.Address
+ if err != nil {
+ if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) {
+ return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress
+ }
+ localAddress = r.LocalAddress
+ }
+
+ for {
+ option, done, err := optIter.Next()
+ if done || err != nil {
+ return optIter.ErrCursor, optIter.Finalize(), err
+ }
+ optType := option.Type()
+ if optType == header.IPv4OptionNOPType {
+ optIter.PushNOPOrEnd(optType)
+ continue
+ }
+ if optType == header.IPv4OptionListEndType {
+ optIter.PushNOPOrEnd(optType)
+ return 0 /* errCursor */, optIter.Finalize(), nil /* err */
+ }
+
+ // check for repeating options (multiple NOPs are OK)
+ if seenOptions[optType] {
+ return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate
+ }
+ seenOptions[optType] = true
+
+ optLen := int(option.Size())
+ switch option := option.(type) {
+ case *header.IPv4OptionTimestamp:
+ r.Stats().IP.OptionTSReceived.Increment()
+ if usage.actions().timestamp != optionRemove {
+ clock := r.Stack().Clock()
+ newBuffer := optIter.RemainingBuffer()[:len(*option)]
+ _ = copy(newBuffer, option.Contents())
+ offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage)
+ if err != nil {
+ return optIter.ErrCursor + offset, nil, err
+ }
+ optIter.ConsumeBuffer(optLen)
+ }
+
+ case *header.IPv4OptionRecordRoute:
+ r.Stats().IP.OptionRRReceived.Increment()
+ if usage.actions().recordRoute != optionRemove {
+ newBuffer := optIter.RemainingBuffer()[:len(*option)]
+ _ = copy(newBuffer, option.Contents())
+ offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage)
+ if err != nil {
+ return optIter.ErrCursor + offset, nil, err
+ }
+ optIter.ConsumeBuffer(optLen)
+ }
+
+ default:
+ r.Stats().IP.OptionUnknownReceived.Increment()
+ if usage.actions().unknown == optionPass {
+ newBuffer := optIter.RemainingBuffer()[:optLen]
+ // Arguments already heavily checked.. ignore result.
+ _ = copy(newBuffer, option.Contents())
+ optIter.ConsumeBuffer(optLen)
+ }
+ }
+ }
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 36035c820..61672a5ff 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -15,31 +15,43 @@
package ipv4_test
import (
- "bytes"
+ "context"
"encoding/hex"
- "math/rand"
+ "fmt"
+ "math"
+ "net"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
+const (
+ extraHeaderReserve = 50
+ defaultMTU = 65536
+)
+
func TestExcludeBroadcast(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- const defaultMTU = 65536
ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
if testing.Verbose() {
ep = sniffer.New(ep)
@@ -91,35 +103,672 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much
-// data should already be in the header before WritePacket. extraLength
-// indicates how much extra space should be in the header. The payload is made
-// from many Views of the sizes listed in viewSizes.
-func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) {
- hdr := buffer.NewPrependable(hdrLength + extraLength)
- hdr.Prepend(hdrLength)
- rand.Read(hdr.View())
-
- var views []buffer.View
- totalLength := 0
- for _, s := range viewSizes {
- newView := buffer.NewView(s)
- rand.Read(newView)
- views = append(views, newView)
- totalLength += s
- }
- payload := buffer.NewVectorisedView(totalLength, views)
- return hdr, payload
+// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
+// checks the response.
+func TestIPv4Sanity(t *testing.T) {
+ const (
+ ttl = 255
+ nicID = 1
+ randomSequence = 123
+ randomIdent = 42
+ )
+ var (
+ ipv4Addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
+ PrefixLen: 24,
+ }
+ remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4())
+ )
+
+ tests := []struct {
+ name string
+ headerLength uint8 // value of 0 means "use correct size"
+ badHeaderChecksum bool
+ maxTotalLength uint16
+ transportProtocol uint8
+ TTL uint8
+ options []byte
+ replyOptions []byte // if succeeds, reply should look like this
+ shouldFail bool
+ expectErrorICMP bool
+ ICMPType header.ICMPv4Type
+ ICMPCode header.ICMPv4Code
+ paramProblemPointer uint8
+ }{
+ {
+ name: "valid no options",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ },
+ {
+ name: "bad header checksum",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ badHeaderChecksum: true,
+ shouldFail: true,
+ },
+ // The TTL tests check that we are not rejecting an incoming packet
+ // with a zero or one TTL, which has been a point of confusion in the
+ // past as RFC 791 says: "If this field contains the value zero, then the
+ // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies
+ // for the case of the destination host, stating as follows.
+ //
+ // A host MUST NOT send a datagram with a Time-to-Live (TTL)
+ // value of zero.
+ //
+ // A host MUST NOT discard a datagram just because it was
+ // received with TTL less than 2.
+ {
+ name: "zero TTL",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 0,
+ },
+ {
+ name: "one TTL",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ },
+ {
+ name: "End options",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{0, 0, 0, 0},
+ replyOptions: []byte{0, 0, 0, 0},
+ },
+ {
+ name: "NOP options",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{1, 1, 1, 1},
+ replyOptions: []byte{1, 1, 1, 1},
+ },
+ {
+ name: "NOP and End options",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{1, 1, 0, 0},
+ replyOptions: []byte{1, 1, 0, 0},
+ },
+ {
+ name: "bad header length",
+ headerLength: header.IPv4MinimumSize - 1,
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ },
+ {
+ name: "bad total length (0)",
+ maxTotalLength: 0,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ },
+ {
+ name: "bad total length (ip - 1)",
+ maxTotalLength: uint16(header.IPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ },
+ {
+ name: "bad total length (ip + icmp - 1)",
+ maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ },
+ {
+ name: "bad protocol",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: 99,
+ TTL: ttl,
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4DstUnreachable,
+ ICMPCode: header.ICMPv4ProtoUnreachable,
+ },
+ {
+ name: "timestamp option overflow",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 12, 13, 0x11,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ replyOptions: []byte{
+ 68, 12, 13, 0x21,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ },
+ {
+ name: "timestamp option overflow full",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 12, 13, 0xF1,
+ // ^ Counter full (15/0xF)
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 3,
+ replyOptions: []byte{},
+ },
+ {
+ name: "unknown option",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{10, 4, 9, 0},
+ // ^^
+ // The unknown option should be stripped out of the reply.
+ replyOptions: []byte{},
+ },
+ {
+ name: "bad option - length 0",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 0, 9, 0,
+ // ^
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ },
+ {
+ name: "bad option - length big",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 9, 9, 0,
+ // ^
+ // There are only 8 bytes allocated to options so 9 bytes of timestamp
+ // space is not possible. (Second byte)
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ },
+ {
+ // This tests for some linux compatible behaviour.
+ // The ICMP pointer returned is 22 for Linux but the
+ // error is actually in spot 21.
+ name: "bad option - length bad",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ // Timestamps are in multiples of 4 or 8 but never 7.
+ // The option space should be padded out.
+ options: []byte{
+ 68, 7, 5, 0,
+ // ^ ^ Linux points here which is wrong.
+ // | Not a multiple of 4
+ 1, 2, 3,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ name: "multiple type 0 with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 24, 21, 0x00,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0, 0, 0, 0,
+ },
+ replyOptions: []byte{
+ 68, 24, 25, 0x00,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
+ // The timestamp area is full so add to the overflow count.
+ name: "multiple type 1 timestamps",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 20, 21, 0x11,
+ // ^
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ },
+ // Overflow count is the top nibble of the 4th byte.
+ replyOptions: []byte{
+ 68, 20, 21, 0x21,
+ // ^
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ },
+ },
+ {
+ name: "multiple type 1 timestamps with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 28, 21, 0x01,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ replyOptions: []byte{
+ 68, 28, 29, 0x01,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ 192, 168, 1, 58, // New IP Address.
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
+ // Needs 8 bytes for a type 1 timestamp but there are only 4 free.
+ name: "bad timer element alignment",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 20, 17, 0x01,
+ // ^^ ^^ 20 byte area, next free spot at 17.
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ // End of option list with illegal option after it, which should be ignored.
+ {
+ name: "end of options list",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 12, 13, 0x11,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 10, 3, 99,
+ },
+ replyOptions: []byte{
+ 68, 12, 13, 0x21,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 0, 0, 0, // 3 bytes unknown option
+ }, // ^ End of options hides following bytes.
+ },
+ {
+ // Timestamp with a size too small.
+ name: "timestamp truncated",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{68, 1, 0, 0},
+ // ^ Smallest possible is 8.
+ shouldFail: true,
+ },
+ {
+ name: "single record route with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 7, 4, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 7, 8, // 3 byte header
+ 192, 168, 1, 58, // New IP Address.
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ name: "multiple record route with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 23, 20, // 3 byte header
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0, 0, 0, 0,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 23, 24,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 192, 168, 1, 58, // New IP Address.
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ name: "single record route with no room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ // Unlike timestamp, this should just succeed.
+ name: "multiple record route with no room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 23, 24, // 3 byte header
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 23, 24,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ // Confirm linux bug for bug compatibility.
+ // Linux returns slot 22 but the error is in slot 21.
+ name: "multiple record route with not enough room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 8, 8, // 3 byte header
+ // ^ ^ Linux points here. We must too.
+ // | Not enough room. 1 byte free, need 4.
+ 1, 2, 3, 4,
+ 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ replyOptions: []byte{},
+ },
+ {
+ name: "duplicate record route",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 0, 0, // pad
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 7,
+ replyOptions: []byte{},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ Clock: clock,
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e := channel.New(1, ipv4.MaxTotalSize, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ }
+ // Advance the clock by some unimportant amount to make
+ // sure it's all set up.
+ clock.Advance(time.Millisecond * 0x10203040)
+
+ // Default routes for IPv4 so ICMP can find a route to the remote
+ // node when attempting to send the ICMP Echo Reply.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // Round up the header size to the next multiple of 4 as RFC 791, page 11
+ // says: "Internet Header Length is the length of the internet header
+ // in 32 bit words..." and on page 23: "The internet header padding is
+ // used to ensure that the internet header ends on a 32 bit boundary."
+ ipHeaderLength := ((header.IPv4MinimumSize + len(test.options)) + header.IPv4IHLStride - 1) & ^(header.IPv4IHLStride - 1)
+
+ if ipHeaderLength > header.IPv4MaximumHeaderSize {
+ t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ }
+ totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(totalLen))
+ icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+
+ // Specify ident/seq to make sure we get the same in the response.
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv4Echo)
+ icmp.SetCode(header.ICMPv4UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(^header.Checksum(icmp, 0))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+ if test.maxTotalLength < totalLen {
+ totalLen = test.maxTotalLength
+ }
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(ipHeaderLength),
+ TotalLength: totalLen,
+ Protocol: test.transportProtocol,
+ TTL: test.TTL,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: ipv4Addr.Address,
+ })
+ if n := copy(ip.Options(), test.options); n != len(test.options) {
+ t.Fatalf("options larger than available space: copied %d/%d bytes", n, len(test.options))
+ }
+ // Override the correct value if the test case specified one.
+ if test.headerLength != 0 {
+ ip.SetHeaderLength(test.headerLength)
+ }
+ ip.SetChecksum(0)
+ ipHeaderChecksum := ip.CalculateChecksum()
+ if test.badHeaderChecksum {
+ ipHeaderChecksum += 42
+ }
+ ip.SetChecksum(^ipHeaderChecksum)
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+ reply, ok := e.Read()
+ if !ok {
+ if test.shouldFail {
+ if test.expectErrorICMP {
+ t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode)
+ }
+ return // Expected silent failure.
+ }
+ t.Fatal("expected ICMP echo reply missing")
+ }
+
+ // We didn't expect a packet. Register our surprise but carry on to
+ // provide more information about what we got.
+ if test.shouldFail && !test.expectErrorICMP {
+ t.Error("unexpected packet response")
+ }
+
+ // Check the route that brought the packet to us.
+ if reply.Route.LocalAddress != ipv4Addr.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address)
+ }
+ if reply.Route.RemoteAddress != remoteIPv4Addr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr)
+ }
+
+ // Make sure it's all in one buffer for checker.
+ replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader()))
+
+ // At this stage we only know it's probably an IP+ICMP header so verify
+ // that much.
+ checker.IPv4(t, replyIPHeader,
+ checker.SrcAddr(ipv4Addr.Address),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ ),
+ )
+
+ // Don't proceed any further if the checker found problems.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // OK it's ICMP. We can safely look at the type now.
+ replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
+ switch replyICMPHeader.Type() {
+ case header.ICMPv4ParamProblem:
+ if !test.shouldFail {
+ t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer())
+ }
+ if !test.expectErrorICMP {
+ t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer())
+ }
+ checker.IPv4(t, replyIPHeader,
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Type(test.ICMPType),
+ checker.ICMPv4Code(test.ICMPCode),
+ checker.ICMPv4Pointer(test.paramProblemPointer),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
+ return
+ case header.ICMPv4DstUnreachable:
+ if !test.shouldFail {
+ t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply",
+ header.ICMPv4DstUnreachable, replyICMPHeader.Code())
+ }
+ if !test.expectErrorICMP {
+ t.Fatalf("got ICMP error packet type %d, code %d, wanted no response",
+ header.ICMPv4DstUnreachable, replyICMPHeader.Code())
+ }
+ checker.IPv4(t, replyIPHeader,
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Type(test.ICMPType),
+ checker.ICMPv4Code(test.ICMPCode),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
+ return
+ case header.ICMPv4EchoReply:
+ if test.shouldFail {
+ if !test.expectErrorICMP {
+ t.Error("got Echo Reply packet, want no response")
+ } else {
+ t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode)
+ }
+ }
+ // If the IP options change size then the packet will change size, so
+ // some IP header fields will need to be adjusted for the checks.
+ sizeChange := len(test.replyOptions) - len(test.options)
+
+ checker.IPv4(t, replyIPHeader,
+ checker.IPv4HeaderLength(ipHeaderLength+sizeChange),
+ checker.IPv4Options(test.replyOptions),
+ checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Seq(randomSequence),
+ checker.ICMPv4Ident(randomIdent),
+ ),
+ )
+ default:
+ t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d",
+ replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem)
+ }
+ })
+ }
}
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
-func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) {
- t.Helper()
- // Make a complete array of the sourcePacketInfo packet.
- source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
- source = append(source, sourcePacketInfo.Header.View()...)
- source = append(source, sourcePacketInfo.Data.ToView()...)
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
+ // Make a complete array of the sourcePacket packet.
+ source := header.IPv4(packets[0].NetworkHeader().View())
+ vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
+ source = append(source, vv.ToView()...)
// Make a copy of the IP header, which will be modified in some fields to make
// an expected header.
@@ -127,350 +776,925 @@ func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketIn
sourceCopy.SetChecksum(0)
sourceCopy.SetFlagsFragmentOffset(0, 0)
sourceCopy.SetTotalLength(0)
- var offset uint16
// Build up an array of the bytes sent.
- var reassembledPayload []byte
+ var reassembledPayload buffer.VectorisedView
for i, packet := range packets {
// Confirm that the packet is valid.
- allBytes := packet.Header.View().ToVectorisedView()
- allBytes.Append(packet.Data)
- ip := header.IPv4(allBytes.ToView())
- if !ip.IsValid(len(ip)) {
- t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
+ allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views())
+ fragmentIPHeader := header.IPv4(allBytes.ToView())
+ if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) {
+ return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader))
+ }
+ if got := len(fragmentIPHeader); got > int(mtu) {
+ return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu)
}
- if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want {
- t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want)
+ if got := fragmentIPHeader.TransportProtocol(); got != proto {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto))
}
- if got, want := len(ip), int(mtu); got > want {
- t.Errorf("fragment is too large, got %d want %d", got, want)
+ if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve {
+ return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve)
}
- if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want {
- t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
+ if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want {
+ return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want)
}
- if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want {
- t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
+ if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want {
+ return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want)
}
- if i < len(packets)-1 {
- sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
+ if wantFragments[i].more {
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset)
} else {
- sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset)
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset)
}
- reassembledPayload = append(reassembledPayload, ip.Payload()...)
- offset += ip.TotalLength() - uint16(ip.HeaderLength())
+ reassembledPayload.AppendView(packet.TransportHeader().View())
+ reassembledPayload.Append(packet.Data)
// Clear out the checksum and length from the ip because we can't compare
// it.
- sourceCopy.SetTotalLength(uint16(len(ip)))
+ sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize)
sourceCopy.SetChecksum(0)
sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
- if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) {
- t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()]))
+ if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" {
+ return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
}
}
- expected := source[source.HeaderLength():]
- if !bytes.Equal(reassembledPayload, expected) {
- t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected))
+
+ expected := buffer.View(source[source.HeaderLength():])
+ if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" {
+ return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
}
-}
-type errorChannel struct {
- *channel.Endpoint
- Ch chan stack.PacketBuffer
- packetCollectorErrors []*tcpip.Error
+ return nil
}
-// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
-// will return successive errors from packetCollectorErrors until the list is
-// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
- return &errorChannel{
- Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan stack.PacketBuffer, size),
- packetCollectorErrors: packetCollectorErrors,
- }
+type fragmentInfo struct {
+ offset uint16
+ more bool
+ payloadSize uint16
}
-// Drain removes all outbound packets from the channel and counts them.
-func (e *errorChannel) Drain() int {
- c := 0
- for {
- select {
- case <-e.Ch:
- c++
- default:
- return c
- }
- }
+var fragmentationTests = []struct {
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transportHeaderLength int
+ payloadSize int
+ wantFragments []fragmentInfo
+}{
+ {
+ description: "No fragmentation",
+ mtu: 1280,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1000, more: false},
+ },
+ },
+ {
+ description: "Fragmented",
+ mtu: 1280,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1256, more: true},
+ {offset: 1256, payloadSize: 744, more: false},
+ },
+ },
+ {
+ description: "Fragmented with the minimum mtu",
+ mtu: header.IPv4MinimumMTU,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 100,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 48, more: true},
+ {offset: 48, payloadSize: 48, more: true},
+ {offset: 96, payloadSize: 4, more: false},
+ },
+ },
+ {
+ description: "Fragmented with mtu not a multiple of 8",
+ mtu: header.IPv4MinimumMTU + 1,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 100,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 48, more: true},
+ {offset: 48, payloadSize: 48, more: true},
+ {offset: 96, payloadSize: 4, more: false},
+ },
+ },
+ {
+ description: "No fragmentation with big header",
+ mtu: 2000,
+ gso: nil,
+ transportHeaderLength: 100,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1100, more: false},
+ },
+ },
+ {
+ description: "Fragmented with gso none",
+ mtu: 1280,
+ gso: &stack.GSO{Type: stack.GSONone},
+ transportHeaderLength: 0,
+ payloadSize: 1400,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1256, more: true},
+ {offset: 1256, payloadSize: 144, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header",
+ mtu: 1280,
+ gso: nil,
+ transportHeaderLength: 100,
+ payloadSize: 1200,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1256, more: true},
+ {offset: 1256, payloadSize: 44, more: false},
+ },
+ },
+ {
+ description: "Fragmented with MTU smaller than header",
+ mtu: 300,
+ gso: nil,
+ transportHeaderLength: 1000,
+ payloadSize: 500,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 280, more: true},
+ {offset: 280, payloadSize: 280, more: true},
+ {offset: 560, payloadSize: 280, more: true},
+ {offset: 840, payloadSize: 280, more: true},
+ {offset: 1120, payloadSize: 280, more: true},
+ {offset: 1400, payloadSize: 100, more: false},
+ },
+ },
}
-// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error {
- select {
- case e.Ch <- pkt:
- default:
- }
+func TestFragmentationWritePacket(t *testing.T) {
+ const ttl = 42
- nextError := (*tcpip.Error)(nil)
- if len(e.packetCollectorErrors) > 0 {
- nextError = e.packetCollectorErrors[0]
- e.packetCollectorErrors = e.packetCollectorErrors[1:]
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ source := pkt.Clone()
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != nil {
+ t.Fatalf("r.WritePacket(_, _, _) = %s", err)
+ }
+ if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ })
}
- return nextError
}
-type context struct {
- stack.Route
- linkEP *errorChannel
-}
-
-func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
- // Make the packet and write it.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- })
- ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- s.CreateNIC(1, ep)
- const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
- )
- s.AddAddress(1, ipv4.ProtocolNumber, src)
- {
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }})
- }
- r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute got %v, want %v", err, nil)
- }
- return context{
- Route: r,
- linkEP: ep,
+func TestFragmentationWritePackets(t *testing.T) {
+ const ttl = 42
+ writePacketsTests := []struct {
+ description string
+ insertBefore int
+ insertAfter int
+ }{
+ {
+ description: "Single packet",
+ insertBefore: 0,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet before",
+ insertBefore: 1,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet after",
+ insertBefore: 0,
+ insertAfter: 1,
+ },
+ {
+ description: "With packet before and after",
+ insertBefore: 1,
+ insertAfter: 1,
+ },
}
-}
+ tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber)
-func TestFragmentation(t *testing.T) {
- var manyPayloadViewsSizes [1000]int
- for i := range manyPayloadViewsSizes {
- manyPayloadViewsSizes[i] = 7
- }
- fragTests := []struct {
- description string
- mtu uint32
- gso *stack.GSO
- hdrLength int
- extraLength int
- payloadViewsSizes []int
- expectedFrags int
- }{
- {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
- {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
- {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
- {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
- }
-
- for _, ft := range fragTests {
- t.Run(ft.description, func(t *testing.T) {
- hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
- source := stack.PacketBuffer{
- Header: hdr,
- // Save the source payload because WritePacket will modify it.
- Data: payload.Clone(nil),
- }
- c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: payload,
- })
- if err != nil {
- t.Errorf("err got %v, want %v", err, nil)
- }
+ for _, test := range writePacketsTests {
+ t.Run(test.description, func(t *testing.T) {
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ var pkts stack.PacketBufferList
+ for i := 0; i < test.insertBefore; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ pkts.PushBack(pkt.Clone())
+ for i := 0; i < test.insertAfter; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
- var results []stack.PacketBuffer
- L:
- for {
- select {
- case pi := <-c.linkEP.Ch:
- results = append(results, pi)
- default:
- break L
- }
- }
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
- if got, want := len(results), ft.expectedFrags; got != want {
- t.Errorf("len(result) got %d, want %d", got, want)
- }
- if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet len(result) got %d, want %d", got, want)
+ wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
+ n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ })
+ if err != nil {
+ t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err)
+ }
+ if n != wantTotalPackets {
+ t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets)
+ }
+ if got := len(ep.WrittenPackets); got != wantTotalPackets {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+
+ if wantTotalPackets == 0 {
+ return
+ }
+
+ fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
+ if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ })
}
- compareFragments(t, results, source, ft.mtu)
})
}
}
-// TestFragmentationErrors checks that errors are returned from write packet
+// TestFragmentationErrors checks that errors are returned from WritePacket
// correctly.
func TestFragmentationErrors(t *testing.T) {
- fragTests := []struct {
+ const ttl = 42
+
+ tests := []struct {
description string
mtu uint32
- hdrLength int
- payloadViewsSizes []int
- packetCollectorErrors []*tcpip.Error
+ transportHeaderLength int
+ payloadSize int
+ allowPackets int
+ outgoingErrors int
+ mockError *tcpip.Error
+ wantError *tcpip.Error
}{
- {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
- {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ {
+ description: "No frag",
+ mtu: 2000,
+ payloadSize: 1000,
+ transportHeaderLength: 0,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on first frag",
+ mtu: 500,
+ payloadSize: 1000,
+ transportHeaderLength: 0,
+ allowPackets: 0,
+ outgoingErrors: 3,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on second frag",
+ mtu: 500,
+ payloadSize: 1000,
+ transportHeaderLength: 0,
+ allowPackets: 1,
+ outgoingErrors: 2,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on first frag MTU smaller than header",
+ mtu: 500,
+ transportHeaderLength: 1000,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 4,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error when MTU is smaller than IPv4 minimum MTU",
+ mtu: header.IPv4MinimumMTU - 1,
+ transportHeaderLength: 0,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrInvalidEndpointState,
+ },
}
- for _, ft := range fragTests {
+ for _, ft := range tests {
t.Run(ft.description, func(t *testing.T) {
- hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
- c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: payload,
- })
- for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
- if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
- t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
- }
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != ft.wantError {
+ t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
}
- // We only need to check that last error because all the ones before are
- // nil.
- if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
- t.Errorf("err got %v, want %v", got, want)
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
+ t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
}
- if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
- t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
}
})
}
}
func TestInvalidFragments(t *testing.T) {
- // These packets have both IHL and TotalLength set to 0.
- testCases := []struct {
+ const (
+ nicID = 1
+ linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ addr1 = "\x0a\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x02"
+ tos = 0
+ ident = 1
+ ttl = 48
+ protocol = 6
+ )
+
+ payloadGen := func(payloadLen int) []byte {
+ payload := make([]byte, payloadLen)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = 0x30
+ }
+ return payload
+ }
+
+ type fragmentData struct {
+ ipv4fields header.IPv4Fields
+ payload []byte
+ autoChecksum bool // if true, the Checksum field will be overwritten.
+ }
+
+ tests := []struct {
name string
- packets [][]byte
+ fragments []fragmentData
wantMalformedIPPackets uint64
wantMalformedFragments uint64
}{
{
- "ihl_totallen_zero_valid_frag_offset",
- [][]byte{
- {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL and TotalLength zero, FragmentOffset non-zero",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: 0,
+ TOS: tos,
+ TotalLength: 0,
+ ID: ident,
+ Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments,
+ FragmentOffset: 59776,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(12),
+ autoChecksum: true,
+ },
+ },
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 0,
+ },
+ {
+ name: "IHL and TotalLength zero, FragmentOffset zero",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: 0,
+ TOS: tos,
+ TotalLength: 0,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(12),
+ autoChecksum: true,
+ },
+ },
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 0,
+ },
+ {
+ // Payload 17 octets and Fragment offset 65520
+ // Leading to the fragment end to be past 65536.
+ name: "fragment ends past 65536",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 17,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 65520,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(17),
+ autoChecksum: true,
+ },
+ },
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
+ },
+ {
+ // Payload 16 octets and fragment offset 65520
+ // Leading to the fragment end to be exactly 65536.
+ name: "fragment ends exactly at 65536",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 65520,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(16),
+ autoChecksum: true,
+ },
+ },
+ wantMalformedIPPackets: 0,
+ wantMalformedFragments: 0,
+ },
+ {
+ name: "IHL less than IPv4 minimum size",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 12,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 28,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 1944,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 12,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize - 12,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
},
- 1,
- 0,
+ wantMalformedIPPackets: 2,
+ wantMalformedFragments: 0,
},
{
- "ihl_totallen_zero_invalid_frag_offset",
- [][]byte{
- {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "fragment with short TotalLength and extra payload",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize + 4,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 28,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 28816,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize + 4,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 4,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
},
- 1,
- 0,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
{
- // Total Length of 37(20 bytes IP header + 17 bytes of
- // payload)
- // Frag Offset of 0x1ffe = 8190*8 = 65520
- // Leading to the fragment end to be past 65535.
- "ihl_totallen_valid_invalid_frag_offset_1",
- [][]byte{
- {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "multiple fragments with More Fragments flag set to false",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 128,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 8,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
- // The following 3 tests were found by running a fuzzer and were
- // triggering a panic in the IPv4 reassembler code.
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
+ },
+ })
+ e := channel.New(0, 1500, linkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ }
+
+ for _, f := range test.fragments {
+ pktSize := header.IPv4MinimumSize + len(f.payload)
+ hdr := buffer.NewPrependable(pktSize)
+
+ ip := header.IPv4(hdr.Prepend(pktSize))
+ ip.Encode(&f.ipv4fields)
+ copy(ip[header.IPv4MinimumSize:], f.payload)
+
+ if f.autoChecksum {
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ }
+
+ vv := hdr.View().ToVectorisedView()
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
+ t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
+ }
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
+ t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
+ }
+ })
+ }
+}
+
+func TestFragmentReassemblyTimeout(t *testing.T) {
+ const (
+ nicID = 1
+ linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ addr1 = "\x0a\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x02"
+ tos = 0
+ ident = 1
+ ttl = 48
+ protocol = 99
+ data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT"
+ )
+
+ type fragmentData struct {
+ ipv4fields header.IPv4Fields
+ payload []byte
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ expectICMP bool
+ }{
{
- "ihl_less_than_ipv4_minimum_size_1",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "first fragment only",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:16],
+ },
},
- 2,
- 0,
+ expectICMP: true,
},
{
- "ihl_less_than_ipv4_minimum_size_2",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "two first fragments",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:16],
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:16],
+ },
},
- 2,
- 0,
+ expectICMP: true,
},
{
- "ihl_less_than_ipv4_minimum_size_3",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "second fragment only",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 8,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[16:],
+ },
},
- 2,
- 0,
+ expectICMP: false,
},
{
- "fragment_with_short_total_len_extra_payload",
- [][]byte{
- {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "two fragments with a gap",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:8],
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 16,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[16:],
+ },
},
- 1,
- 1,
+ expectICMP: true,
},
{
- "multiple_fragments_with_more_fragments_set_to_false",
- [][]byte{
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ name: "two fragments with a gap in reverse order",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 16,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[16:],
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:8],
+ },
},
- 1,
- 1,
+ expectICMP: true,
},
}
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- const nicID tcpip.NICID = 42
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{
- ipv4.NewProtocol(),
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
},
+ Clock: clock,
})
+ e := channel.New(1, 1500, linkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ }})
+
+ var firstFragmentSent buffer.View
+ for _, f := range test.fragments {
+ pktSize := header.IPv4MinimumSize
+ hdr := buffer.NewPrependable(pktSize)
+
+ ip := header.IPv4(hdr.Prepend(pktSize))
+ ip.Encode(&f.ipv4fields)
- var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30})
- var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31})
- ep := channel.New(10, 1500, linkAddr)
- s.CreateNIC(nicID, sniffer.New(ep))
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
- for _, pkt := range tc.packets {
- ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(f.payload)
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
})
+
+ if firstFragmentSent == nil && ip.FragmentOffset() == 0 {
+ firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader())
+ }
+
+ e.InjectInbound(header.IPv4ProtocolNumber, pkt)
}
- if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want {
- t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
+ clock.Advance(ipv4.ReassembleTimeout)
+
+ reply, ok := e.Read()
+ if !test.expectICMP {
+ if ok {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+ return
}
- if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want {
- t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
+ if !ok {
+ t.Fatal("expected ICMP error message missing")
+ }
+ if firstFragmentSent == nil {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
}
+
+ checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(addr2),
+ checker.DstAddr(addr1),
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4TimeExceeded),
+ checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout),
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Payload([]byte(firstFragmentSent)),
+ ),
+ )
})
}
}
@@ -478,12 +1702,16 @@ func TestInvalidFragments(t *testing.T) {
// TestReceiveFragments feeds fragments in through the incoming packet path to
// test reassembly
func TestReceiveFragments(t *testing.T) {
- const addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
- const addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
- const nicID = 1
+ const (
+ nicID = 1
+
+ addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
+ addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
+ addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3
+ )
// Build and return a UDP header containing payload.
- udpGen := func(payloadLen int, multiplier uint8) buffer.View {
+ udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View {
payload := buffer.NewView(payloadLen)
for i := 0; i < len(payload); i++ {
payload[i] = uint8(i) * multiplier
@@ -499,20 +1727,32 @@ func TestReceiveFragments(t *testing.T) {
Length: uint16(udpLength),
})
copy(u.Payload(), payload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength))
sum = header.Checksum(payload, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
return hdr.View()
}
// UDP header plus a payload of 0..256
- ipv4Payload1 := udpGen(256, 1)
- udpPayload1 := ipv4Payload1[header.UDPMinimumSize:]
+ ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2)
+ udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:]
+ ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2)
+ udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:]
// UDP header plus a payload of 0..256 in increments of 2.
- ipv4Payload2 := udpGen(128, 2)
- udpPayload2 := ipv4Payload2[header.UDPMinimumSize:]
+ ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2)
+ udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:]
+ // UDP header plus a payload of 0..256 in increments of 3.
+ // Used to test cases where the fragment blocks are not a multiple of
+ // the fragment block size of 8 (RFC 791 section 3.1 page 14).
+ ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2)
+ udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:]
+ // Used to test the max reassembled payload length (65,535 octets).
+ ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2)
+ udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:]
type fragmentData struct {
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
id uint16
flags uint8
fragmentOffset uint16
@@ -528,22 +1768,40 @@ func TestReceiveFragments(t *testing.T) {
name: "No fragmentation",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: 0,
fragmentOffset: 0,
- payload: ipv4Payload1,
+ payload: ipv4Payload1Addr1ToAddr2,
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "No fragmentation with size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 0,
+ payload: ipv4Payload3Addr1ToAddr2,
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
},
{
name: "More fragments without payload",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1,
+ payload: ipv4Payload1Addr1ToAddr2,
},
},
expectedPayloads: nil,
@@ -552,10 +1810,12 @@ func TestReceiveFragments(t *testing.T) {
name: "Non-zero fragment offset without payload",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: 0,
fragmentOffset: 8,
- payload: ipv4Payload1,
+ payload: ipv4Payload1Addr1ToAddr2,
},
},
expectedPayloads: nil,
@@ -564,34 +1824,108 @@ func TestReceiveFragments(t *testing.T) {
name: "Two fragments",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1Addr1ToAddr2[64:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments out of order",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1Addr1ToAddr2[64:],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with last fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload3Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: 0,
fragmentOffset: 64,
- payload: ipv4Payload1[64:],
+ payload: ipv4Payload3Addr1ToAddr2[64:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with first fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload3Addr1ToAddr2[:63],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 63,
+ payload: ipv4Payload3Addr1ToAddr2[63:],
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: nil,
},
{
name: "Second fragment has MoreFlags set",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 64,
- payload: ipv4Payload1[64:],
+ payload: ipv4Payload1Addr1ToAddr2[64:],
},
},
expectedPayloads: nil,
@@ -600,16 +1934,20 @@ func TestReceiveFragments(t *testing.T) {
name: "Two fragments with different IDs",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 2,
flags: 0,
fragmentOffset: 64,
- payload: ipv4Payload1[64:],
+ payload: ipv4Payload1Addr1ToAddr2[64:],
},
},
expectedPayloads: nil,
@@ -618,31 +1956,113 @@ func TestReceiveFragments(t *testing.T) {
name: "Two interleaved fragmented packets",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 2,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload2[:64],
+ payload: ipv4Payload2Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: 0,
fragmentOffset: 64,
- payload: ipv4Payload1[64:],
+ payload: ipv4Payload1Addr1ToAddr2[64:],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 2,
flags: 0,
fragmentOffset: 64,
- payload: ipv4Payload2[64:],
+ payload: ipv4Payload2Addr1ToAddr2[64:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2},
+ },
+ {
+ name: "Two interleaved fragmented packets from different sources but with same ID",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr3ToAddr2[:32],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1Addr1ToAddr2[64:],
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 32,
+ payload: ipv4Payload1Addr3ToAddr2[32:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2},
+ },
+ {
+ name: "Fragment without followup",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload4Addr1ToAddr2[:65512],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 65512,
+ payload: ipv4Payload4Addr1ToAddr2[65512:],
},
},
- expectedPayloads: [][]byte{udpPayload1, udpPayload2},
+ expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
},
}
@@ -650,8 +2070,8 @@ func TestReceiveFragments(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
// Setup a stack and endpoint.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00"))
if err := s.CreateNIC(nicID, e); err != nil {
@@ -691,16 +2111,17 @@ func TestReceiveFragments(t *testing.T) {
FragmentOffset: frag.fragmentOffset,
TTL: 64,
Protocol: uint8(header.UDPProtocolNumber),
- SrcAddr: addr1,
- DstAddr: addr2,
+ SrcAddr: frag.srcAddr,
+ DstAddr: frag.dstAddr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
vv := hdr.View().ToVectorisedView()
vv.AppendView(frag.payload)
- e.InjectInbound(header.IPv4ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: vv,
- })
+ }))
}
if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
@@ -723,3 +2144,394 @@ func TestReceiveFragments(t *testing.T) {
})
}
}
+
+func TestWriteStats(t *testing.T) {
+ const nPackets = 3
+
+ tests := []struct {
+ name string
+ setup func(*testing.T, *stack.Stack)
+ allowPackets int
+ expectSent int
+ expectDropped int
+ expectWritten int
+ }{
+ {
+ name: "Accept all",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets,
+ expectDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Accept all with error",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: nPackets - 1,
+ expectSent: nPackets - 1,
+ expectDropped: 0,
+ expectWritten: nPackets - 1,
+ }, {
+ name: "Drop all",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %s", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule that matches only 1
+ // of the 3 packets.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %s", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectDropped: 1,
+ expectWritten: nPackets,
+ },
+ }
+
+ // Parameterize the tests to run with both WritePacket and WritePackets.
+ writers := []struct {
+ name string
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ }{
+ {
+ name: "WritePacket",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ nWritten := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
+ return nWritten, err
+ }
+ nWritten++
+ }
+ return nWritten, nil
+ },
+ }, {
+ name: "WritePackets",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
+ },
+ },
+ }
+
+ for _, writer := range writers {
+ t.Run(writer.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ rt := buildRoute(t, ep)
+
+ var pkts stack.PacketBufferList
+ for i := 0; i < nPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(header.UDPMinimumSize)
+ pkts.PushBack(pkt)
+ }
+
+ test.setup(t, rt.Stack())
+
+ nWritten, _ := writer.writePackets(&rt, pkts)
+
+ if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
+ t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
+ t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ }
+ if nWritten != test.expectWritten {
+ t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ }
+ })
+ }
+ })
+ }
+}
+
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC(1, _) failed: %s", err)
+ }
+ const (
+ src = "\x10\x00\x00\x01"
+ dst = "\x10\x00\x00\x02"
+ )
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
+ }
+ {
+ mask := tcpip.AddressMask(header.IPv4Broadcast)
+ subnet, err := tcpip.NewSubnet(dst, mask)
+ if err != nil {
+ t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err)
+ }
+ return rt
+}
+
+// limitedMatcher is an iptables matcher that matches after a certain number of
+// packets are checked against it.
+type limitedMatcher struct {
+ limit int
+}
+
+// Name implements Matcher.Name.
+func (*limitedMatcher) Name() string {
+ return "limitedMatcher"
+}
+
+// Match implements Matcher.Match.
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+ if lm.limit == 0 {
+ return true, false
+ }
+ lm.limit--
+ return false, false
+}
+
+func TestPacketQueing(t *testing.T) {
+ const nicID = 1
+
+ var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ )
+
+ tests := []struct {
+ name string
+ rxPkt func(*channel.Endpoint)
+ checkResp func(*testing.T, *channel.Endpoint)
+ }{
+ {
+ name: "ICMP Error",
+ rxPkt: func(e *channel.Endpoint) {
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize,
+ TTL: ipv4.DefaultTTL,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != header.IPv4ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ },
+ },
+
+ {
+ name: "Ping",
+ rxPkt: func(e *channel.Endpoint) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != header.IPv4ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode)))
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(1, defaultMTU, host1NICLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ // Receive a packet to trigger link resolution before a response is sent.
+ test.rxPkt(e)
+
+ // Wait for a ARP request since link address resolution should be
+ // performed.
+ {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != arp.ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+ }
+ rep := header.ARP(p.Pkt.NetworkHeader().View())
+ if got := rep.Op(); got != header.ARPRequest {
+ t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address)
+ }
+ }
+
+ // Send an ARP reply to complete link address resolution.
+ {
+ hdr := buffer.View(make([]byte, header.ARPSize))
+ packet := header.ARP(hdr)
+ packet.SetIPv4OverEthernet()
+ packet.SetOp(header.ARPReply)
+ copy(packet.HardwareAddressSender(), host2NICLinkAddr)
+ copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address)
+ copy(packet.HardwareAddressTarget(), host1NICLinkAddr)
+ copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address)
+ e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.ToVectorisedView(),
+ }))
+ }
+
+ // Expect the response now that the link address has resolved.
+ test.checkResp(t, e)
+
+ // Since link resolution was already performed, it shouldn't be performed
+ // again.
+ test.rxPkt(e)
+ test.checkResp(t, e)
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 3f71fc520..0ac24a6fb 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -5,14 +5,18 @@ package(licenses = ["notice"])
go_library(
name = "ipv6",
srcs = [
+ "dhcpv6configurationfromndpra_string.go",
"icmp.go",
"ipv6.go",
+ "ndp.go",
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
@@ -32,13 +36,16 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go
new file mode 100644
index 000000000..09ba133b1
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT.
+
+package ipv6
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[DHCPv6NoConfiguration-1]
+ _ = x[DHCPv6ManagedAddress-2]
+ _ = x[DHCPv6OtherConfigurations-3]
+}
+
+const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations"
+
+var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66}
+
+func (i DHCPv6ConfigurationFromNDPRA) String() string {
+ i -= 1
+ if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) {
+ return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i+1), 10) + ")"
+ }
+ return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]]
+}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index bdf3a0d25..3c15e41a7 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -27,7 +27,7 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
if !ok {
return
@@ -39,8 +39,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
// is truncated, which would cause IsValid to return false.
//
// Drop packet if it doesn't have the basic IPv6 header or if the
- // original source address doesn't match the endpoint's address.
- if hdr.SourceAddress() != e.id.LocalAddress {
+ // original source address doesn't match an address we own.
+ src := hdr.SourceAddress()
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -67,20 +68,76 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
}
// Deliver the control packet to the transport endpoint.
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) {
+// getLinkAddrOption searches NDP options for a given link address option using
+// the provided getAddr function as a filter. Returns the link address if
+// found; otherwise, returns the zero link address value. Also returns true if
+// the options are valid as per the wire format, false otherwise.
+func getLinkAddrOption(it header.NDPOptionIterator, getAddr func(header.NDPOption) tcpip.LinkAddress) (tcpip.LinkAddress, bool) {
+ var linkAddr tcpip.LinkAddress
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ return "", false
+ }
+ if done {
+ break
+ }
+ if addr := getAddr(opt); len(addr) != 0 {
+ // No RFCs define what to do when an NDP message has multiple Link-Layer
+ // Address options. Since no interface can have multiple link-layer
+ // addresses, we consider such messages invalid.
+ if len(linkAddr) != 0 {
+ return "", false
+ }
+ linkAddr = addr
+ }
+ }
+ return linkAddr, true
+}
+
+// getSourceLinkAddr searches NDP options for the source link address option.
+// Returns the link address if found; otherwise, returns the zero link address
+// value. Also returns true if the options are valid as per the wire format,
+// false otherwise.
+func getSourceLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
+ return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress {
+ if src, ok := opt.(header.NDPSourceLinkLayerAddressOption); ok {
+ return src.EthernetAddress()
+ }
+ return ""
+ })
+}
+
+// getTargetLinkAddr searches NDP options for the target link address option.
+// Returns the link address if found; otherwise, returns the zero link address
+// value. Also returns true if the options are valid as per the wire format,
+// false otherwise.
+func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
+ return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress {
+ if dst, ok := opt.(header.NDPTargetLinkLayerAddressOption); ok {
+ return dst.EthernetAddress()
+ }
+ return ""
+ })
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their
+ // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
if !ok {
received.Invalid.Increment()
return
}
h := header.ICMPv6(v)
- iph := header.IPv6(netHeader)
+ iph := header.IPv6(pkt.NetworkHeader().View())
// Validate ICMPv6 checksum before processing the packet.
//
@@ -113,8 +170,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
return
}
pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
- mtu := header.ICMPv6(hdr).MTU()
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
+ networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize)
+ if err != nil {
+ networkMTU = 0
+ }
+ e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt)
case header.ICMPv6DstUnreachable:
received.DstUnreachable.Increment()
@@ -125,13 +185,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
}
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch header.ICMPv6(hdr).Code() {
+ case header.ICMPv6NetworkUnreachable:
+ e.handleControl(stack.ControlNetworkUnreachable, 0, pkt)
case header.ICMPv6PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
+ if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
@@ -141,22 +203,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
// NDP messages cannot be fragmented. Also note that in the common case NDP
// datagrams are very small and ToView() will not incur allocations.
ns := header.NDPNeighborSolicit(payload.ToView())
- it, err := ns.Options().Iter(true)
- if err != nil {
- // If we have a malformed NDP NS option, drop the packet.
+ targetAddr := ns.TargetAddress()
+
+ // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast
+ // address.
+ if header.IsV6MulticastAddress(targetAddr) {
received.Invalid.Increment()
return
}
- targetAddr := ns.TargetAddress()
- s := r.Stack()
- if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
- // We will only get an error if the NIC is unrecognized, which should not
- // happen. For now, drop this packet.
- //
- // TODO(b/141002840): Handle this better?
- return
- } else if isTentative {
+ if e.hasTentativeAddr(targetAddr) {
// If the target address is tentative and the source of the packet is a
// unicast (specified) address, then the source of the packet is
// attempting to perform address resolution on the target. In this case,
@@ -169,7 +225,20 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
// stack know so it can handle such a scenario and do nothing further with
// the NS.
if r.RemoteAddress == header.IPv6Any {
- s.DupTentativeAddrDetected(e.nicID, targetAddr)
+ // We would get an error if the address no longer exists or the address
+ // is no longer tentative (DAD resolved between the call to
+ // hasTentativeAddr and this point). Both of these are valid scenarios:
+ // 1) An address may be removed at any time.
+ // 2) As per RFC 4862 section 5.4, DAD is not a perfect:
+ // "Note that the method for detecting duplicates
+ // is not completely reliable, and it is possible that duplicate
+ // addresses will still exist"
+ //
+ // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
+ // address is detected for an assigned address.
+ if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
+ }
}
// Do not handle neighbor solicitations targeted to an address that is
@@ -181,48 +250,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
// so the packet is processed as defined in RFC 4861, as per RFC 4862
// section 5.4.3.
- // Is the NS targetting us?
- if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
+ // Is the NS targeting us?
+ if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
return
}
- // If the NS message contains the Source Link-Layer Address option, update
- // the link address cache with the value of the option.
- //
- // TODO(b/148429853): Properly process the NS message and do Neighbor
- // Unreachability Detection.
var sourceLinkAddr tcpip.LinkAddress
- for {
- opt, done, err := it.Next()
+ {
+ it, err := ns.Options().Iter(false /* check */)
if err != nil {
- // This should never happen as Iter(true) above did not return an error.
- panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
- }
- if done {
- break
+ // Options are not valid as per the wire format, silently drop the
+ // packet.
+ received.Invalid.Increment()
+ return
}
- switch opt := opt.(type) {
- case header.NDPSourceLinkLayerAddressOption:
- // No RFCs define what to do when an NS message has multiple Source
- // Link-Layer Address options. Since no interface can have multiple
- // link-layer addresses, we consider such messages invalid.
- if len(sourceLinkAddr) != 0 {
- received.Invalid.Increment()
- return
- }
-
- sourceLinkAddr = opt.EthernetAddress()
+ sourceLinkAddr, ok = getSourceLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
}
}
- unspecifiedSource := r.RemoteAddress == header.IPv6Any
-
// As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST
// NOT be included when the source IP address is the unspecified address.
// Otherwise, on link layers that have addresses this option MUST be
// included in multicast solicitations and SHOULD be included in unicast
// solicitations.
+ unspecifiedSource := r.RemoteAddress == header.IPv6Any
if len(sourceLinkAddr) == 0 {
if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource {
received.Invalid.Increment()
@@ -231,55 +286,88 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
} else if unspecifiedSource {
received.Invalid.Increment()
return
+ } else if e.nud != nil {
+ e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr)
- }
-
- // ICMPv6 Neighbor Solicit messages are always sent to
- // specially crafted IPv6 multicast addresses. As a result, the
- // route we end up with here has as its LocalAddress such a
- // multicast address. It would be nonsense to claim that our
- // source address is a multicast address, so we manually set
- // the source address to the target address requested in the
- // solicit message. Since that requires mutating the route, we
- // must first clone it.
- r := r.Clone()
- defer r.Release()
- r.LocalAddress = targetAddr
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
+ }
- // As per RFC 4861 section 7.2.4, if the the source of the solicitation is
- // the unspecified address, the node MUST set the Solicited flag to zero and
- // multicast the advertisement to the all-nodes address.
- solicited := true
+ // As per RFC 4861 section 7.1.1:
+ // A node MUST silently discard any received Neighbor Solicitation
+ // messages that do not satisfy all of the following validity checks:
+ // ...
+ // - If the IP source address is the unspecified address, the IP
+ // destination address is a solicited-node multicast address.
+ if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) {
+ received.Invalid.Increment()
+ return
+ }
+
+ // As per RFC 4861 section 7.2.4:
+ //
+ // If the source of the solicitation is the unspecified address, the node
+ // MUST [...] and multicast the advertisement to the all-nodes address.
+ //
+ remoteAddr := r.RemoteAddress
if unspecifiedSource {
- solicited = false
- r.RemoteAddress = header.IPv6AllNodesMulticastAddress
+ remoteAddr = header.IPv6AllNodesMulticastAddress
}
- // If the NS has a source link-layer option, use the link address it
- // specifies as the remote link address for the response instead of the
- // source link address of the packet.
+ // Even if we were able to receive a packet from some remote, we may not
+ // have a route to it - the remote may be blocked via routing rules. We must
+ // always consult our routing table and find a route to the remote before
+ // sending any packet.
+ r, err := e.protocol.stack.FindRoute(e.nic.ID(), targetAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ // If we cannot find a route to the destination, silently drop the packet.
+ return
+ }
+ defer r.Release()
+
+ // If the NS has a source link-layer option, resolve the route immediately
+ // to avoid querying the neighbor table when the neighbor entry was updated
+ // as probing the neighbor table for a link address will transition the
+ // entry's state from stale to delay.
+ //
+ // Note, if the source link address is unspecified and this is a unicast
+ // solicitation, we may need to perform neighbor discovery to send the
+ // neighbor advertisement response. This is expected as per RFC 4861 section
+ // 7.2.4:
+ //
+ // Because unicast Neighbor Solicitations are not required to include a
+ // Source Link-Layer Address, it is possible that a node sending a
+ // solicited Neighbor Advertisement does not have a corresponding link-
+ // layer address for its neighbor in its Neighbor Cache. In such
+ // situations, a node will first have to use Neighbor Discovery to
+ // determine the link-layer address of its neighbor (i.e., send out a
+ // multicast Neighbor Solicitation).
//
- // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link
- // address cache for the right destination link address instead of manually
- // patching the route with the remote link address if one is specified in a
- // Source Link-Layer Address option.
if len(sourceLinkAddr) != 0 {
- r.RemoteLinkAddress = sourceLinkAddr
+ r.ResolveWith(sourceLinkAddr)
}
optsSerializer := header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress),
+ header.NDPTargetLinkLayerAddressOption(e.nic.LinkAddress()),
}
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
- packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ neighborAdvertSize := header.ICMPv6NeighborAdvertMinimumSize + optsSerializer.Length()
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborAdvertSize,
+ })
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+ packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize))
packet.SetType(header.ICMPv6NeighborAdvert)
na := header.NDPNeighborAdvert(packet.NDPPayload())
- na.SetSolicitedFlag(solicited)
+
+ // As per RFC 4861 section 7.2.4:
+ //
+ // If the source of the solicitation is the unspecified address, the node
+ // MUST set the Solicited flag to zero and [..]. Otherwise, the node MUST
+ // set the Solicited flag to one and [..].
+ //
+ na.SetSolicitedFlag(!unspecifiedSource)
na.SetOverrideFlag(true)
na.SetTargetAddress(targetAddr)
- opts := na.Options()
- opts.Serialize(optsSerializer)
+ na.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
@@ -288,9 +376,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
//
// The IP Hop Limit field has a value of 255, i.e., the packet
// could not possibly have been forwarded by a router.
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- }); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -298,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
+ if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize {
received.Invalid.Increment()
return
}
@@ -308,28 +394,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
// 5, NDP messages cannot be fragmented. Also note that in the common case
// NDP datagrams are very small and ToView() will not incur allocations.
na := header.NDPNeighborAdvert(payload.ToView())
- it, err := na.Options().Iter(true)
- if err != nil {
- // If we have a malformed NDP NA option, drop the packet.
- received.Invalid.Increment()
- return
- }
-
targetAddr := na.TargetAddress()
- stack := r.Stack()
-
- if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil {
- // We will only get an error if the NIC is unrecognized, which should not
- // happen. For now short-circuit this packet.
- //
- // TODO(b/141002840): Handle this better?
- return
- } else if isTentative {
+ if e.hasTentativeAddr(targetAddr) {
// We just got an NA from a node that owns an address we are performing
// DAD on, implying the address is not unique. In this case we let the
// stack know so it can handle such a scenario and do nothing furthur with
// the NDP NA.
- stack.DupTentativeAddrDetected(e.nicID, targetAddr)
+ //
+ // We would get an error if the address no longer exists or the address
+ // is no longer tentative (DAD resolved between the call to
+ // hasTentativeAddr and this point). Both of these are valid scenarios:
+ // 1) An address may be removed at any time.
+ // 2) As per RFC 4862 section 5.4, DAD is not a perfect:
+ // "Note that the method for detecting duplicates
+ // is not completely reliable, and it is possible that duplicate
+ // addresses will still exist"
+ //
+ // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
+ // address is detected for an assigned address.
+ if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
+ }
+ return
+ }
+
+ it, err := na.Options().Iter(false /* check */)
+ if err != nil {
+ // If we have a malformed NDP NA option, drop the packet.
+ received.Invalid.Increment()
return
}
@@ -342,58 +434,59 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
// TODO(b/143147598): Handle the scenario described above. Also inform the
// netstack integration that a duplicate address was detected outside of
// DAD.
+ targetLinkAddr, ok := getTargetLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
// If the NA message has the target link layer option, update the link
// address cache with the link address for the target of the message.
- //
- // TODO(b/148429853): Properly process the NA message and do Neighbor
- // Unreachability Detection.
- var targetLinkAddr tcpip.LinkAddress
- for {
- opt, done, err := it.Next()
- if err != nil {
- // This should never happen as Iter(true) above did not return an error.
- panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
- }
- if done {
- break
- }
-
- switch opt := opt.(type) {
- case header.NDPTargetLinkLayerAddressOption:
- // No RFCs define what to do when an NA message has multiple Target
- // Link-Layer Address options. Since no interface can have multiple
- // link-layer addresses, we consider such messages invalid.
- if len(targetLinkAddr) != 0 {
- received.Invalid.Increment()
- return
- }
-
- targetLinkAddr = opt.EthernetAddress()
+ if e.nud == nil {
+ if len(targetLinkAddr) != 0 {
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
}
+ return
}
- if len(targetLinkAddr) != 0 {
- e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
- }
+ e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{
+ Solicited: na.SolicitedFlag(),
+ Override: na.OverrideFlag(),
+ IsRouter: na.RouterFlag(),
+ })
case header.ICMPv6EchoRequest:
received.EchoRequest.Increment()
- icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize)
+ icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize)
if !ok {
received.Invalid.Increment()
return
}
- pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize)
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
- packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+
+ // As per RFC 4291 section 2.7, multicast addresses must not be used as
+ // source addresses in IPv6 packets.
+ localAddr := r.LocalAddress
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ localAddr = ""
+ }
+
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ // If we cannot find a route to the destination, silently drop the packet.
+ return
+ }
+ defer r.Release()
+
+ replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
+ Data: pkt.Data,
+ })
+ packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
copy(packet, icmpHdr)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
- Header: hdr,
- Data: pkt.Data,
- }); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -415,27 +508,75 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6RouterSolicit:
received.RouterSolicit.Increment()
- if !isNDPValid() {
+
+ //
+ // Validate the RS as per RFC 4861 section 6.1.1.
+ //
+
+ // Is the NDP payload of sufficient size to hold a Router Solictation?
+ if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize {
received.Invalid.Increment()
return
}
- case header.ICMPv6RouterAdvert:
- received.RouterAdvert.Increment()
+ stack := r.Stack()
- // Is the NDP payload of sufficient size to hold a Router
- // Advertisement?
- if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() {
+ // Is the networking stack operating as a router?
+ if !stack.Forwarding(ProtocolNumber) {
+ // ... No, silently drop the packet.
+ received.RouterOnlyPacketsDroppedByHost.Increment()
+ return
+ }
+
+ // Note that in the common case NDP datagrams are very small and ToView()
+ // will not incur allocations.
+ rs := header.NDPRouterSolicit(payload.ToView())
+ it, err := rs.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the packet.
received.Invalid.Increment()
return
}
- routerAddr := iph.SourceAddress()
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+
+ // If the RS message has the source link layer option, update the link
+ // address cache with the link address for the source of the message.
+ if len(sourceLinkAddr) != 0 {
+ // As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST
+ // NOT be included when the source IP address is the unspecified address.
+ // Otherwise, it SHOULD be included on link layers that have addresses.
+ if r.RemoteAddress == header.IPv6Any {
+ received.Invalid.Increment()
+ return
+ }
+
+ if e.nud != nil {
+ // A RS with a specified source IP address modifies the NUD state
+ // machine in the same way a reachability probe would.
+ e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ }
+ }
+
+ case header.ICMPv6RouterAdvert:
+ received.RouterAdvert.Increment()
//
// Validate the RA as per RFC 4861 section 6.1.2.
//
+ // Is the NDP payload of sufficient size to hold a Router Advertisement?
+ if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ routerAddr := iph.SourceAddress()
+
// Is the IP Source Address a link-local address?
if !header.IsV6LinkLocalAddress(routerAddr) {
// ...No, silently drop the packet.
@@ -443,16 +584,18 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
return
}
- // The remainder of payload must be only the router advertisement, so
- // payload.ToView() always returns the advertisement. Per RFC 6980 section
- // 5, NDP messages cannot be fragmented. Also note that in the common case
- // NDP datagrams are very small and ToView() will not incur allocations.
+ // Note that in the common case NDP datagrams are very small and ToView()
+ // will not incur allocations.
ra := header.NDPRouterAdvert(payload.ToView())
- opts := ra.Options()
+ it, err := ra.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
- // Are options valid as per the wire format?
- if _, err := opts.Iter(true); err != nil {
- // ...No, silently drop the packet.
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
received.Invalid.Increment()
return
}
@@ -462,12 +605,33 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
// as RFC 4861 section 6.1.2 is concerned.
//
- // Tell the NIC to handle the RA.
- stack := r.Stack()
- rxNICID := r.NICID()
- stack.HandleNDPRA(rxNICID, routerAddr, ra)
+ // If the RA has the source link layer option, update the link address
+ // cache with the link address for the advertised router.
+ if len(sourceLinkAddr) != 0 && e.nud != nil {
+ e.nud.HandleProbe(routerAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ }
+
+ e.mu.Lock()
+ e.mu.ndp.handleRA(routerAddr, ra)
+ e.mu.Unlock()
case header.ICMPv6RedirectMsg:
+ // TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating
+ // this redirect message, as per RFC 4871 section 7.3.3:
+ //
+ // "A Neighbor Cache entry enters the STALE state when created as a
+ // result of receiving packets other than solicited Neighbor
+ // Advertisements (i.e., Router Solicitations, Router Advertisements,
+ // Redirects, and Neighbor Solicitations). These packets contain the
+ // link-layer address of either the sender or, in the case of Redirect,
+ // the redirection target. However, receipt of these link-layer
+ // addresses does not confirm reachability of the forward-direction path
+ // to that node. Placing a newly created Neighbor Cache entry for which
+ // the link-layer address is known in the STALE state provides assurance
+ // that path failures are detected quickly. In addition, should a cached
+ // link-layer address be modified due to receiving one of the above
+ // messages, the state SHOULD also be set to STALE to provide prompt
+ // verification that the path to the new link-layer address is working."
received.RedirectMsg.Increment()
if !isNDPValid() {
received.Invalid.Increment()
@@ -479,20 +643,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
}
}
-const (
- ndpSolicitedFlag = 1 << 6
- ndpOverrideFlag = 1 << 5
-
- ndpOptSrcLinkAddr = 1
- ndpOptDstLinkAddr = 2
-
- icmpV6FlagOffset = 4
- icmpV6OptOffset = 24
- icmpV6LengthOffset = 25
-)
-
-var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-
var _ stack.LinkAddressResolver = (*protocol)(nil)
// LinkAddressProtocol implements stack.LinkAddressResolver.
@@ -501,40 +651,46 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
- snaddr := header.SolicitedNodeAddr(addr)
-
- // TODO(b/148672031): Use stack.FindRoute instead of manually creating the
- // route here. Note, we would need the nicID to do this properly so the right
- // NIC (associated to linkEP) is used to send the NDP NS message.
- r := &stack.Route{
- LocalAddress: localAddr,
- RemoteAddress: snaddr,
- RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr),
+func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
+ remoteAddr := targetAddr
+ if len(remoteLinkAddr) == 0 {
+ remoteAddr = header.SolicitedNodeAddr(targetAddr)
+ remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr)
}
- hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- copy(pkt[icmpV6OptOffset-len(addr):], addr)
- pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
- pkt[icmpV6LengthOffset] = 1
- copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
-
- length := uint16(hdr.UsedLength())
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: length,
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- })
- // TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{
- Header: hdr,
+ r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+ r.ResolveWith(remoteLinkAddr)
+
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()),
+ }
+ neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize,
})
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+ packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
+ packet.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(packet.NDPPayload())
+ ns.SetTargetAddress(targetAddr)
+ ns.Options().Serialize(optsSerializer)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ stat := p.stack.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, pkt); err != nil {
+ stat.Dropped.Increment()
+ return err
+ }
+
+ stat.NeighborSolicit.Increment()
+ return nil
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
@@ -544,3 +700,192 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
return tcpip.LinkAddress([]byte(nil)), false
}
+
+// ======= ICMP Error packet generation =========
+
+// icmpReason is a marker interface for IPv6 specific ICMP errors.
+type icmpReason interface {
+ isICMPReason()
+}
+
+// icmpReasonParameterProblem is an error during processing of extension headers
+// or the fixed header defined in RFC 4443 section 3.4.
+type icmpReasonParameterProblem struct {
+ code header.ICMPv6Code
+
+ // respondToMulticast indicates that we are sending a packet that falls under
+ // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2:
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are
+ // two exceptions to this rule: (1) the Packet Too Big Message
+ // (Section 3.2) to allow Path MTU discovery to work for IPv6
+ // multicast, and (2) the Parameter Problem Message, Code 2
+ // (Section 3.4) reporting an unrecognized IPv6 option (see
+ // Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ respondToMulticast bool
+
+ // pointer is defined in the RFC 4443 setion 3.4 which reads:
+ //
+ // Pointer Identifies the octet offset within the invoking packet
+ // where the error was detected.
+ //
+ // The pointer will point beyond the end of the ICMPv6
+ // packet if the field in error is beyond what can fit
+ // in the maximum size of an ICMPv6 error message.
+ pointer uint32
+}
+
+func (*icmpReasonParameterProblem) isICMPReason() {}
+
+// icmpReasonPortUnreachable is an error where the transport protocol has no
+// listener and no alternative means to inform the sender.
+type icmpReasonPortUnreachable struct{}
+
+func (*icmpReasonPortUnreachable) isICMPReason() {}
+
+// icmpReasonReassemblyTimeout is an error where insufficient fragments are
+// received to complete reassembly of a packet within a configured time after
+// the reception of the first-arriving fragment of that packet.
+type icmpReasonReassemblyTimeout struct{}
+
+func (*icmpReasonReassemblyTimeout) isICMPReason() {}
+
+// returnError takes an error descriptor and generates the appropriate ICMP
+// error packet for IPv6 and sends it.
+func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ // Only send ICMP error if the address is not a multicast v6
+ // address and the source is not the unspecified address.
+ //
+ // There are exceptions to this rule.
+ // See: point e.3) RFC 4443 section-2.4
+ //
+ // (e) An ICMPv6 error message MUST NOT be originated as a result of
+ // receiving the following:
+ //
+ // (e.1) An ICMPv6 error message.
+ //
+ // (e.2) An ICMPv6 redirect message [IPv6-DISC].
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are
+ // two exceptions to this rule: (1) the Packet Too Big Message
+ // (Section 3.2) to allow Path MTU discovery to work for IPv6
+ // multicast, and (2) the Parameter Problem Message, Code 2
+ // (Section 3.4) reporting an unrecognized IPv6 option (see
+ // Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ //
+ var allowResponseToMulticast bool
+ if reason, ok := reason.(*icmpReasonParameterProblem); ok {
+ allowResponseToMulticast = reason.respondToMulticast
+ }
+
+ if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any {
+ return nil
+ }
+
+ // Even if we were able to receive a packet from some remote, we may not have
+ // a route to it - the remote may be blocked via routing rules. We must always
+ // consult our routing table and find a route to the remote before sending any
+ // packet.
+ route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ // From this point on, the incoming route should no longer be used; route
+ // must be used to send the ICMP error.
+ r = nil
+
+ stats := p.stack.Stats().ICMP
+ sent := stats.V6PacketsSent
+ if !p.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+
+ if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
+ // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
+ // Unfortunately at this time ICMP Packets do not have a transport
+ // header separated out. It is in the Data part so we need to
+ // separate it out now. We will just pretend it is a minimal length
+ // ICMP packet as we don't really care if any later bits of a
+ // larger ICMP packet are in the header view or in the Data view.
+ transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize)
+ if !ok {
+ return nil
+ }
+ typ := header.ICMPv6(transport).Type()
+ if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg {
+ return nil
+ }
+ }
+
+ // As per RFC 4443 section 2.4
+ //
+ // (c) Every ICMPv6 error message (type < 128) MUST include
+ // as much of the IPv6 offending (invoking) packet (the
+ // packet that caused the error) as possible without making
+ // the error message packet exceed the minimum IPv6 MTU
+ // [IPv6].
+ mtu := int(route.MTU())
+ if mtu > header.IPv6MinimumMTU {
+ mtu = header.IPv6MinimumMTU
+ }
+ headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
+ available := int(mtu) - headerLen
+ if available < header.IPv6MinimumSize {
+ return nil
+ }
+ payloadLen := network.Size() + transport.Size() + pkt.Data.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+ payload := network.ToVectorisedView()
+ payload.AppendView(transport)
+ payload.Append(pkt.Data)
+ payload.CapLength(payloadLen)
+
+ newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: headerLen,
+ Data: payload,
+ })
+ newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+
+ icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
+ var counter *tcpip.StatCounter
+ switch reason := reason.(type) {
+ case *icmpReasonParameterProblem:
+ icmpHdr.SetType(header.ICMPv6ParamProblem)
+ icmpHdr.SetCode(reason.code)
+ icmpHdr.SetTypeSpecific(reason.pointer)
+ counter = sent.ParamProblem
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6PortUnreachable)
+ counter = sent.DstUnreachable
+ case *icmpReasonReassemblyTimeout:
+ icmpHdr.SetType(header.ICMPv6TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
+ counter = sent.TimeExceeded
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data))
+ if err := route.WritePacket(
+ nil, /* gso */
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ },
+ newPkt,
+ ); err != nil {
+ sent.Dropped.Increment()
+ return err
+ }
+ counter.Increment()
+ return nil
+}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d412ff688..aa8b5f2e5 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -16,37 +16,57 @@ package ipv6
import (
"context"
+ "net"
"reflect"
"strings"
"testing"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
+ nicID = 1
+
linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
+
+ // Extra time to use when waiting for an async event to occur.
+ defaultAsyncPositiveEventTimeout = 30 * time.Second
)
var (
lladdr0 = header.LinkLocalAddr(linkAddr0)
lladdr1 = header.LinkLocalAddr(linkAddr1)
+ lladdr2 = header.LinkLocalAddr(linkAddr2)
)
type stubLinkEndpoint struct {
stack.LinkEndpoint
}
+func (*stubLinkEndpoint) MTU() uint32 {
+ return defaultMTU
+}
+
func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return 0
+ // Indicate that resolution for link layer addresses is required to send
+ // packets over this link. This is needed so the NIC knows to allocate a
+ // neighbor table.
+ return stack.CapabilityResolutionRequired
}
func (*stubLinkEndpoint) MaxHeaderLength() uint16 {
@@ -57,7 +77,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
-func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error {
+func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
return nil
}
@@ -67,7 +87,8 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) {
+func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition {
+ return stack.TransportPacketHandled
}
type stubLinkAddressCache struct {
@@ -81,16 +102,225 @@ func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtoco
func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {
}
+type stubNUDHandler struct {
+ probeCount int
+ confirmationCount int
+}
+
+var _ stack.NUDHandler = (*stubNUDHandler)(nil)
+
+func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) {
+ s.probeCount++
+}
+
+func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) {
+ s.confirmationCount++
+}
+
+func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) {
+}
+
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.LinkEndpoint
+
+ nicID tcpip.NICID
+}
+
+func (*testInterface) ID() tcpip.NICID {
+ return nicID
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ r := stack.Route{
+ NetProto: protocol,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
+}
+
func TestICMPCounts(t *testing.T) {
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ UseNeighborCache: test.useNeighborCache,
+ })
+ {
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
+ r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ }
+ defer r.Release()
+
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
+ types := []struct {
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ }{
+ {
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ },
+ {
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ },
+ {
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ },
+ {
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ },
+ }
+
+ handleIPv6Payload := func(icmp header.ICMPv6) {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize,
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ep.HandlePacket(&r, pkt)
+ }
+
+ for _, typ := range types {
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ handleIPv6Payload(icmp)
+ }
+
+ // Construct an empty ICMP packet so that
+ // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
+ handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
+
+ icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
+ if got, want := s.Value(), uint64(1); got != want {
+ t.Errorf("got %s = %d, want = %d", name, got, want)
+ }
+ })
+ if t.Failed() {
+ t.Logf("stats:\n%+v", s.Stats())
+ }
+ })
+ }
+}
+
+func TestICMPCountsWithNeighborCache(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ UseNeighborCache: true,
})
{
- if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
}
}
@@ -102,7 +332,7 @@ func TestICMPCounts(t *testing.T) {
s.SetRouteTable(
[]tcpip.Route{{
Destination: subnet,
- NIC: 1,
+ NIC: nicID,
}},
)
}
@@ -111,14 +341,16 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{})
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
}
defer r.Release()
@@ -179,36 +411,33 @@ func TestICMPCounts(t *testing.T) {
},
}
- handleIPv6Payload := func(hdr buffer.Prependable) {
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ handleIPv6Payload := func(icmp header.ICMPv6) {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize,
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
+ PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
+ ep.HandlePacket(&r, pkt)
}
for _, typ := range types {
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
-
- handleIPv6Payload(hdr)
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ handleIPv6Payload(icmp)
}
// Construct an empty ICMP packet so that
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
- handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize))
+ handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
@@ -252,35 +481,34 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab
func newTestContext(t *testing.T) *testContext {
c := &testContext{
s0: stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
}),
s1: stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
}),
}
- const defaultMTU = 65536
- c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+ c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0)
wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
if testing.Verbose() {
wrappedEP0 = sniffer.New(wrappedEP0)
}
- if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
+ if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
- if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
- if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
+ if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
+ if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil {
t.Fatalf("AddAddress lladdr1: %v", err)
}
@@ -291,7 +519,7 @@ func newTestContext(t *testing.T) *testContext {
c.s0.SetRouteTable(
[]tcpip.Route{{
Destination: subnet0,
- NIC: 1,
+ NIC: nicID,
}},
)
subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
@@ -301,7 +529,7 @@ func newTestContext(t *testing.T) *testContext {
c.s1.SetRouteTable(
[]tcpip.Route{{
Destination: subnet1,
- NIC: 1,
+ NIC: nicID,
}},
)
@@ -325,12 +553,10 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
pi, _ := args.src.ReadContext(context.Background())
{
- views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
- size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size()
- vv := buffer.NewVectorisedView(size, views)
- args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{
- Data: vv,
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()),
})
+ args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt)
}
if pi.Proto != ProtocolNumber {
@@ -342,7 +568,9 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
}
- ipv6 := header.IPv6(pi.Pkt.Header.View())
+ // Pull the full payload since network header. Needed for header.IPv6 to
+ // extract its payload.
+ ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader()))
transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
if transProto != header.ICMPv6ProtocolNumber {
t.Errorf("unexpected transport protocol number %d", transProto)
@@ -362,9 +590,9 @@ func TestLinkResolution(t *testing.T) {
c := newTestContext(t)
defer c.cleanup()
- r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
}
defer r.Release()
@@ -379,14 +607,14 @@ func TestLinkResolution(t *testing.T) {
var wq waiter.Queue
ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err)
}
for {
- _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}})
+ _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}})
if resCh != nil {
if err != tcpip.ErrNoLinkAddress {
- t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err)
+ t.Fatalf("ep.Write(_) = (_, <non-nil>, %s), want = (_, <non-nil>, tcpip.ErrNoLinkAddress)", err)
}
for _, args := range []routeArgs{
{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
@@ -402,7 +630,7 @@ func TestLinkResolution(t *testing.T) {
continue
}
if err != nil {
- t.Fatalf("ep.Write(_) = _, _, %s", err)
+ t.Fatalf("ep.Write(_) = (_, _, %s)", err)
}
break
}
@@ -427,6 +655,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
size int
extraData []byte
statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ routerOnly bool
}{
{
name: "DstUnreachable",
@@ -483,6 +712,8 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterSolicit
},
+ // Hosts MUST silently discard any received Router Solicitation messages.
+ routerOnly: true,
},
{
name: "RouterAdvert",
@@ -519,86 +750,133 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
},
}
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr0)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, typ := range types {
+ for _, isRouter := range []bool{false, true} {
+ name := typ.name
+ if isRouter {
+ name += " (Router)"
+ }
+ t.Run(name, func(t *testing.T) {
+ e := channel.New(0, 1280, linkAddr0)
+
+ // Indicate that resolution for link layer addresses is required to
+ // send packets over this link. This is needed so the NIC knows to
+ // allocate a neighbor table.
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: test.useNeighborCache,
+ })
+ if isRouter {
+ // Enabling forwarding makes the stack act as a router.
+ s.SetForwarding(ProtocolNumber, true)
+ }
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
+ }
+
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(checksum bool) {
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ if checksum {
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
+ }
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
+ })
+ e.InjectInbound(ProtocolNumber, pkt)
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ routerOnly := stats.RouterOnlyPacketsDroppedByHost
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := routerOnly.Value(); got != 0 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Router only count should not have increased.
+ if got := routerOnly.Value(); got != 0 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ if !isRouter && typ.routerOnly && test.useNeighborCache {
+ // Router only count should have increased.
+ if got := routerOnly.Value(); got != 1 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got)
+ }
+ }
+ })
}
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }},
- )
- }
-
- handleIPv6Payload := func(checksum bool) {
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView()))
- }
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(typ.size + extraDataLen),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
- }
-
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
}
})
}
@@ -699,13 +977,13 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Run(typ.name, func(t *testing.T) {
e := channel.New(10, 1280, linkAddr0)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
}
{
@@ -716,7 +994,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
s.SetRouteTable(
[]tcpip.Route{{
Destination: subnet,
- NIC: 1,
+ NIC: nicID,
}},
)
}
@@ -724,12 +1002,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
icmpSize := size + payloadSize
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(typ)
- payloadFn(pkt.Payload())
+ icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize))
+ icmpHdr.SetType(typ)
+ payloadFn(icmpHdr.Payload())
if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{}))
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -740,9 +1018,10 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
})
+ e.InjectInbound(ProtocolNumber, pkt)
}
stats := s.Stats().ICMP.V6PacketsReceived
@@ -754,7 +1033,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Without setting checksum, the incoming packet should
@@ -765,13 +1044,13 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
}
// Rx count of type typ.typ should not have increased.
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// When checksum is set, it should be received.
handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Invalid count should not have increased again.
if got := invalid.Value(); got != 1 {
@@ -876,14 +1155,14 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Run(typ.name, func(t *testing.T) {
e := channel.New(10, 1280, linkAddr0)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -893,21 +1172,21 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
s.SetRouteTable(
[]tcpip.Route{{
Destination: subnet,
- NIC: 1,
+ NIC: nicID,
}},
)
}
handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
- pkt := header.ICMPv6(hdr.Prepend(size))
- pkt.SetType(typ)
+ icmpHdr := header.ICMPv6(hdr.Prepend(size))
+ icmpHdr.SetType(typ)
payload := buffer.NewView(payloadSize)
payloadFn(payload)
if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView()))
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView()))
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -918,9 +1197,10 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
})
+ e.InjectInbound(ProtocolNumber, pkt)
}
stats := s.Stats().ICMP.V6PacketsReceived
@@ -932,7 +1212,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Without setting checksum, the incoming packet should
@@ -943,13 +1223,13 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
}
// Rx count of type typ.typ should not have increased.
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// When checksum is set, it should be received.
handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Invalid count should not have increased again.
if got := invalid.Value(); got != 1 {
@@ -958,3 +1238,573 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
})
}
}
+
+func TestLinkAddressRequest(t *testing.T) {
+ const nicID = 1
+
+ snaddr := header.SolicitedNodeAddr(lladdr0)
+ mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
+
+ tests := []struct {
+ name string
+ nicAddr tcpip.Address
+ localAddr tcpip.Address
+ remoteLinkAddr tcpip.LinkAddress
+
+ expectedErr *tcpip.Error
+ expectedRemoteAddr tcpip.Address
+ expectedRemoteLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ nicAddr: lladdr1,
+ localAddr: lladdr1,
+ remoteLinkAddr: linkAddr1,
+ expectedRemoteAddr: lladdr0,
+ expectedRemoteLinkAddr: linkAddr1,
+ },
+ {
+ name: "Multicast",
+ nicAddr: lladdr1,
+ localAddr: lladdr1,
+ remoteLinkAddr: "",
+ expectedRemoteAddr: snaddr,
+ expectedRemoteLinkAddr: mcaddr,
+ },
+ {
+ name: "Unicast with unspecified source",
+ nicAddr: lladdr1,
+ remoteLinkAddr: linkAddr1,
+ expectedRemoteAddr: lladdr0,
+ expectedRemoteLinkAddr: linkAddr1,
+ },
+ {
+ name: "Multicast with unspecified source",
+ nicAddr: lladdr1,
+ remoteLinkAddr: "",
+ expectedRemoteAddr: snaddr,
+ expectedRemoteLinkAddr: mcaddr,
+ },
+ {
+ name: "Unicast with unassigned address",
+ localAddr: lladdr1,
+ remoteLinkAddr: linkAddr1,
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Multicast with unassigned address",
+ localAddr: lladdr1,
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Unicast with no local address available",
+ remoteLinkAddr: linkAddr1,
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Multicast with no local address available",
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ }
+
+ for _, test := range tests {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ p := s.NetworkProtocolInstance(ProtocolNumber)
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
+ if err := s.CreateNIC(nicID, linkEP); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if len(test.nicAddr) != 0 {
+ if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ }
+ }
+
+ // We pass a test network interface to LinkAddressRequest with the same NIC
+ // ID and link endpoint used by the NIC we created earlier so that we can
+ // mock a link address request and observe the packets sent to the link
+ // endpoint even though the stack uses the real NIC.
+ if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+ }
+
+ if test.expectedErr != nil {
+ return
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+ if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
+ }
+ if pkt.Route.RemoteAddress != test.expectedRemoteAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr)
+ }
+ if pkt.Route.LocalAddress != lladdr1 {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1)
+ }
+ checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
+ checker.SrcAddr(lladdr1),
+ checker.DstAddr(test.expectedRemoteAddr),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(lladdr0),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}),
+ ))
+ }
+}
+
+func TestPacketQueing(t *testing.T) {
+ const nicID = 1
+
+ var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ )
+
+ tests := []struct {
+ name string
+ rxPkt func(*channel.Endpoint)
+ checkResp func(*testing.T, *channel.Endpoint)
+ }{
+ {
+ name: "ICMP Error",
+ rxPkt: func(e *channel.Endpoint) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+ },
+ },
+
+ {
+ name: "Ping",
+ rxPkt: func(e *channel.Endpoint) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply),
+ checker.ICMPv6Code(header.ICMPv6UnusedCode)))
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+
+ e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ // Receive a packet to trigger link resolution before a response is sent.
+ test.rxPkt(e)
+
+ // Wait for a neighbor solicitation since link address resolution should
+ // be performed.
+ {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}),
+ ))
+ }
+
+ // Send a neighbor advertisement to complete link address resolution.
+ {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
+ pkt := header.ICMPv6(hdr.Prepend(naSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr),
+ })
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ // Expect the response now that the link address has resolved.
+ test.checkResp(t, e)
+
+ // Since link resolution was already performed, it shouldn't be performed
+ // again.
+ test.rxPkt(e)
+ test.checkResp(t, e)
+ })
+ }
+}
+
+func TestCallsToNeighborCache(t *testing.T) {
+ tests := []struct {
+ name string
+ createPacket func() header.ICMPv6
+ multicast bool
+ source tcpip.Address
+ destination tcpip.Address
+ wantProbeCount int
+ wantConfirmationCount int
+ }{
+ {
+ name: "Unicast Neighbor Solicitation without source link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(nsSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ return icmp
+ },
+ source: lladdr1,
+ destination: lladdr0,
+ // "The source link-layer address option SHOULD be included in unicast
+ // solicitations." - RFC 4861 section 4.3
+ //
+ // A Neighbor Advertisement needs to be sent in response, but the
+ // Neighbor Cache shouldn't be updated since we have no useful
+ // information about the sender.
+ wantProbeCount: 0,
+ },
+ {
+ name: "Unicast Neighbor Solicitation with source link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(nsSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ ns.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
+ return icmp
+ },
+ source: lladdr1,
+ destination: lladdr0,
+ wantProbeCount: 1,
+ },
+ {
+ name: "Multicast Neighbor Solicitation without source link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(nsSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ return icmp
+ },
+ source: lladdr1,
+ destination: header.SolicitedNodeAddr(lladdr0),
+ // "The source link-layer address option MUST be included in multicast
+ // solicitations." - RFC 4861 section 4.3
+ wantProbeCount: 0,
+ },
+ {
+ name: "Multicast Neighbor Solicitation with source link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(nsSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ ns.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
+ return icmp
+ },
+ source: lladdr1,
+ destination: header.SolicitedNodeAddr(lladdr0),
+ wantProbeCount: 1,
+ },
+ {
+ name: "Unicast Neighbor Advertisement without target link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize
+ icmp := header.ICMPv6(buffer.NewView(naSize))
+ icmp.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(false)
+ na.SetTargetAddress(lladdr1)
+ return icmp
+ },
+ source: lladdr1,
+ destination: lladdr0,
+ // "When responding to unicast solicitations, the target link-layer
+ // address option can be omitted since the sender of the solicitation has
+ // the correct link-layer address; otherwise, it would not be able to
+ // send the unicast solicitation in the first place."
+ // - RFC 4861 section 4.4
+ wantConfirmationCount: 1,
+ },
+ {
+ name: "Unicast Neighbor Advertisement with target link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(naSize))
+ icmp.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(false)
+ na.SetTargetAddress(lladdr1)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+ return icmp
+ },
+ source: lladdr1,
+ destination: lladdr0,
+ wantConfirmationCount: 1,
+ },
+ {
+ name: "Multicast Neighbor Advertisement without target link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(naSize))
+ icmp.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na.SetSolicitedFlag(false)
+ na.SetOverrideFlag(false)
+ na.SetTargetAddress(lladdr1)
+ return icmp
+ },
+ source: lladdr1,
+ destination: header.IPv6AllNodesMulticastAddress,
+ // "Target link-layer address MUST be included for multicast solicitations
+ // in order to avoid infinite Neighbor Solicitation "recursion" when the
+ // peer node does not have a cache entry to return a Neighbor
+ // Advertisements message." - RFC 4861 section 4.4
+ wantConfirmationCount: 0,
+ },
+ {
+ name: "Multicast Neighbor Advertisement with target link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(naSize))
+ icmp.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na.SetSolicitedFlag(false)
+ na.SetOverrideFlag(false)
+ na.SetTargetAddress(lladdr1)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+ return icmp
+ },
+ source: lladdr1,
+ destination: header.IPv6AllNodesMulticastAddress,
+ wantConfirmationCount: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ UseNeighborCache: true,
+ })
+ {
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+ nudHandler := &stubNUDHandler{}
+ ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{})
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
+ r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ }
+ defer r.Release()
+
+ // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch.
+ r.LocalAddress = test.destination
+
+ icmp := test.createPacket()
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{}))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize,
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: r.RemoteAddress,
+ DstAddr: r.LocalAddress,
+ })
+ ep.HandlePacket(&r, pkt)
+
+ // Confirm the endpoint calls the correct NUDHandler method.
+ if nudHandler.probeCount != test.wantProbeCount {
+ t.Errorf("got nudHandler.probeCount = %d, want = %d", nudHandler.probeCount, test.wantProbeCount)
+ }
+ if nudHandler.confirmationCount != test.wantConfirmationCount {
+ t.Errorf("got nudHandler.confirmationCount = %d, want = %d", nudHandler.confirmationCount, test.wantConfirmationCount)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index daf1fcbc6..1e38f3a9d 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,98 +12,373 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ipv6 contains the implementation of the ipv6 network protocol. To use
-// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv6.NewProtocol() as one of the network
-// protocols when calling stack.New(). Then endpoints can be created by passing
-// ipv6.ProtocolNumber as the network protocol number when calling
-// Stack.NewEndpoint().
+// Package ipv6 contains the implementation of the ipv6 network protocol.
package ipv6
import (
+ "encoding/binary"
"fmt"
+ "hash/fnv"
+ "sort"
"sync/atomic"
+ "time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
+ // As per RFC 8200 section 4.5:
+ // If insufficient fragments are received to complete reassembly of a packet
+ // within 60 seconds of the reception of the first-arriving fragment of that
+ // packet, reassembly of that packet must be abandoned.
+ //
+ // Linux also uses 60 seconds for reassembly timeout:
+ // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456
+ ReassembleTimeout = 60 * time.Second
+
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
- // maxTotalSize is maximum size that can be encoded in the 16-bit
+ // maxPayloadSize is the maximum size that can be encoded in the 16-bit
// PayloadLength field of the ipv6 header.
maxPayloadSize = 0xffff
// DefaultTTL is the default hop limit for IPv6 Packets egressed by
// Netstack.
DefaultTTL = 64
+
+ // buckets for fragment identifiers
+ buckets = 2048
)
+var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+var _ stack.NDPEndpoint = (*endpoint)(nil)
+var _ NDPEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- nicID tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
- linkEP stack.LinkEndpoint
+ nic stack.NetworkInterface
linkAddrCache stack.LinkAddressCache
+ nud stack.NUDHandler
dispatcher stack.TransportDispatcher
- fragmentation *fragmentation.Fragmentation
protocol *protocol
+ stack *stack.Stack
+
+ // enabled is set to 1 when the endpoint is enabled and 0 when it is
+ // disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ mu struct {
+ sync.RWMutex
+
+ addressableEndpointState stack.AddressableEndpointState
+ ndp ndpState
+ }
}
-// DefaultTTL is the default hop limit for this endpoint.
-func (e *endpoint) DefaultTTL() uint8 {
- return e.protocol.DefaultTTL()
+// NICNameFromID is a function that returns a stable name for the specified NIC,
+// even if different NIC IDs are used to refer to the same NIC in different
+// program runs. It is used when generating opaque interface identifiers (IIDs).
+// If the NIC was created with a name, it is passed to NICNameFromID.
+//
+// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
+// generated for the same prefix on differnt NICs.
+type NICNameFromID func(tcpip.NICID, string) string
+
+// OpaqueInterfaceIdentifierOptions holds the options related to the generation
+// of opaque interface indentifiers (IIDs) as defined by RFC 7217.
+type OpaqueInterfaceIdentifierOptions struct {
+ // NICNameFromID is a function that returns a stable name for a specified NIC,
+ // even if the NIC ID changes over time.
+ //
+ // Must be specified to generate the opaque IID.
+ NICNameFromID NICNameFromID
+
+ // SecretKey is a pseudo-random number used as the secret key when generating
+ // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
+ // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
+ // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
+ // change between program runs, unless explicitly changed.
+ //
+ // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
+ // MUST NOT be modified after Stack is created.
+ //
+ // May be nil, but a nil value is highly discouraged to maintain
+ // some level of randomness between nodes.
+ SecretKey []byte
}
-// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
-// the network layer max header length.
-func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
+// InvalidateDefaultRouter implements stack.NDPEndpoint.
+func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.mu.ndp.invalidateDefaultRouter(rtr)
}
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
+// SetNDPConfigurations implements NDPEndpoint.
+func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) {
+ c.validate()
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.mu.ndp.configs = c
}
-// ID returns the ipv6 endpoint ID.
-func (e *endpoint) ID() *stack.NetworkEndpointID {
- return &e.id
+// hasTentativeAddr returns true if addr is tentative on e.
+func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
+ e.mu.RLock()
+ addressEndpoint := e.getAddressRLocked(addr)
+ e.mu.RUnlock()
+ return addressEndpoint != nil && addressEndpoint.GetKind() == stack.PermanentTentative
}
-// PrefixLen returns the ipv6 endpoint subnet prefix length in bits.
-func (e *endpoint) PrefixLen() int {
- return e.prefixLen
+// dupTentativeAddrDetected attempts to inform e that a tentative addr is a
+// duplicate on a link.
+//
+// dupTentativeAddrDetected removes the tentative address if it exists. If the
+// address was generated via SLAAC, an attempt is made to generate a new
+// address.
+func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil {
+ return tcpip.ErrBadAddress
+ }
+
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
+ // attempt will be made to generate a new address for it.
+ if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ return err
+ }
+
+ prefix := addressEndpoint.AddressWithPrefix().Subnet()
+
+ switch t := addressEndpoint.ConfigType(); t {
+ case stack.AddressConfigStatic:
+ case stack.AddressConfigSlaac:
+ e.mu.ndp.regenerateSLAACAddr(prefix)
+ case stack.AddressConfigSlaacTemp:
+ // Do not reset the generation attempts counter for the prefix as the
+ // temporary address is being regenerated in response to a DAD conflict.
+ e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
+ default:
+ panic(fmt.Sprintf("unrecognized address config type = %d", t))
+ }
+
+ return nil
}
-// Capabilities implements stack.NetworkEndpoint.Capabilities.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
+// transitionForwarding transitions the endpoint's forwarding status to
+// forwarding.
+//
+// Must only be called when the forwarding status changes.
+func (e *endpoint) transitionForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if !e.Enabled() {
+ return
+ }
+
+ if forwarding {
+ // When transitioning into an IPv6 router, host-only state (NDP discovered
+ // routers, discovered on-link prefixes, and auto-generated addresses) is
+ // cleaned up/invalidated and NDP router solicitations are stopped.
+ e.mu.ndp.stopSolicitingRouters()
+ e.mu.ndp.cleanupState(true /* hostOnly */)
+ } else {
+ // When transitioning into an IPv6 host, NDP router solicitations are
+ // started.
+ e.mu.ndp.startSolicitingRouters()
+ }
}
-// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
-// underlying protocols).
-func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+// Enable implements stack.NetworkEndpoint.
+func (e *endpoint) Enable() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // If the NIC is not enabled, the endpoint can't do anything meaningful so
+ // don't enable the endpoint.
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ // If the endpoint is already enabled, there is nothing for it to do.
+ if !e.setEnabled(true) {
+ return nil
+ }
+
+ // Join the IPv6 All-Nodes Multicast group if the stack is configured to
+ // use IPv6. This is required to ensure that this node properly receives
+ // and responds to the various NDP messages that are destined to the
+ // all-nodes multicast address. An example is the Neighbor Advertisement
+ // when we perform Duplicate Address Detection, or Router Advertisement
+ // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
+ // section 4.2 for more information.
+ //
+ // Also auto-generate an IPv6 link-local address based on the endpoint's
+ // link address if it is configured to do so. Note, each interface is
+ // required to have IPv6 link-local unicast address, as per RFC 4291
+ // section 2.1.
+
+ // Join the All-Nodes multicast group before starting DAD as responses to DAD
+ // (NDP NS) messages may be sent to the All-Nodes multicast group if the
+ // source address of the NDP NS is the unspecified address, as per RFC 4861
+ // section 7.2.4.
+ if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
+ // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
+ // state.
+ //
+ // Addresses may have aleady completed DAD but in the time since the endpoint
+ // was last enabled, other devices may have acquired the same addresses.
+ var err *tcpip.Error
+ e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ addr := addressEndpoint.AddressWithPrefix().Address
+ if !header.IsV6UnicastAddress(addr) {
+ return true
+ }
+
+ switch addressEndpoint.GetKind() {
+ case stack.Permanent:
+ addressEndpoint.SetKind(stack.PermanentTentative)
+ fallthrough
+ case stack.PermanentTentative:
+ err = e.mu.ndp.startDuplicateAddressDetection(addr, addressEndpoint)
+ return err == nil
+ default:
+ return true
+ }
+ })
+ if err != nil {
+ return err
+ }
+
+ // Do not auto-generate an IPv6 link-local address for loopback devices.
+ if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() {
+ // The valid and preferred lifetime is infinite for the auto-generated
+ // link-local address.
+ e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
+ }
+
+ // If we are operating as a router, then do not solicit routers since we
+ // won't process the RAs anyway.
+ //
+ // Routers do not process Router Advertisements (RA) the same way a host
+ // does. That is, routers do not learn from RAs (e.g. on-link prefixes
+ // and default routers). Therefore, soliciting RAs from other routers on
+ // a link is unnecessary for routers.
+ if !e.protocol.Forwarding() {
+ e.mu.ndp.startSolicitingRouters()
+ }
+
+ return nil
}
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
+// Enabled implements stack.NetworkEndpoint.
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
+}
+
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the endpoint.
+//
+// Returns true if the enabled status was updated.
+func (e *endpoint) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&e.enabled, 1) == 0
+ }
+ return atomic.SwapUint32(&e.enabled, 0) == 1
+}
+
+// Disable implements stack.NetworkEndpoint.
+func (e *endpoint) Disable() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.disableLocked()
+}
+
+func (e *endpoint) disableLocked() {
+ if !e.setEnabled(false) {
+ return
+ }
+
+ e.mu.ndp.stopSolicitingRouters()
+ e.mu.ndp.cleanupState(false /* hostOnly */)
+ e.stopDADForPermanentAddressesLocked()
+
+ // The endpoint may have already left the multicast group.
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
+ }
+}
+
+// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) stopDADForPermanentAddressesLocked() {
+ // Stop DAD for all the tentative unicast addresses.
+ e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ return true
+ }
+
+ addr := addressEndpoint.AddressWithPrefix().Address
+ if header.IsV6UnicastAddress(addr) {
+ e.mu.ndp.stopDuplicateAddressDetection(addr)
+ }
+
+ return true
+ })
+}
+
+// DefaultTTL is the default hop limit for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return e.protocol.DefaultTTL()
+}
+
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize)
+ if err != nil {
+ return 0
}
- return 0
+ return networkMTU
+}
+
+// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
+// underlying protocols).
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 {
- length := uint16(hdr.UsedLength() + payloadSize)
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+ length := uint16(pkt.Size())
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(params.Protocol),
@@ -112,25 +387,97 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- return ip
+ pkt.NetworkProtocolNumber = ProtocolNumber
+}
+
+func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
+ payload := pkt.TransportHeader().View().Size() + pkt.Data.Size()
+ return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU
+}
+
+// handleFragments fragments pkt and calls the handler function on each
+// fragment. It returns the number of fragments handled and the number of
+// fragments left to be processed. The IP header must already be present in the
+// original packet. The transport header protocol number is required to avoid
+// parsing the IPv6 extension headers.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ networkHeader := header.IPv6(pkt.NetworkHeader().View())
+
+ // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
+ // supported for outbound packets, their length should not affect the fragment
+ // maximum payload length because they should only be transmitted once.
+ fragmentPayloadLen := (networkMTU - header.IPv6FragmentHeaderSize) &^ 7
+ if fragmentPayloadLen < header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit {
+ // We need at least 8 bytes of space left for the fragmentable part because
+ // the fragment payload must obviously be non-zero and must be a multiple
+ // of 8 as per RFC 8200 section 4.5:
+ // Each complete fragment, except possibly the last ("rightmost") one, is
+ // an integer multiple of 8 octets long.
+ return 0, 1, tcpip.ErrMessageTooLong
+ }
+
+ if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) {
+ // As per RFC 8200 Section 4.5, the Transport Header is expected to be small
+ // enough to fit in the first fragment.
+ return 0, 1, tcpip.ErrMessageTooLong
+ }
+
+ pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt))
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1)
+
+ var n int
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader, transProto, id)
+ if err := handler(fragPkt); err != nil {
+ return n, pf.RemainingFragmentCount() + 1, err
+ }
+ n++
+ if !more {
+ return n, pf.RemainingFragmentCount(), nil
+ }
+ }
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
- ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
- pkt.NetworkHeader = buffer.View(ip)
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.addIPHeader(r, pkt, params)
+ return e.writePacket(r, gso, pkt, params.Protocol)
+}
+
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
+ if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ // iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesOutputDropped.Increment()
+ return nil
+ }
+
+ // If the packet is manipulated as per NAT Output rules, handle packet
+ // based on destination address and do not send the packet to link
+ // layer.
+ //
+ // TODO(gvisor.dev/issue/170): We should do this for every
+ // packet, rather than only NATted packets, but removing this check
+ // short circuits broadcasts before they are sent out to other hosts.
+ if pkt.NatDone {
+ netHeader := header.IPv6(pkt.NetworkHeader().View())
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ ep.HandlePacket(&route, pkt)
+ return nil
+ }
+ }
if r.Loop&stack.PacketLoop != 0 {
- // The inbound path expects the network header to still be in
- // the PacketBuffer's Data field.
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
- })
+ e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ // The inbound path expects an unparsed packet.
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ }))
loopedR.Release()
}
@@ -138,11 +485,35 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
+
+ if packetMustBeFragmented(pkt, networkMTU, gso) {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
+ // fragment one by one using WritePacket() (current strategy) or if we
+ // want to create a PacketBufferList from the fragments and feed it to
+ // WritePackets(). It'll be faster but cost more memory.
+ return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
+ })
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ return err
+ }
+
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
+
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt)
+ return nil
}
-// WritePackets implements stack.LinkEndpoint.WritePackets.
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
@@ -151,45 +522,163 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return pkts.Len(), nil
}
+ linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params)
- pb.NetworkHeader = buffer.View(ip)
+ e.addIPHeader(r, pb, params)
+
+ networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+ if packetMustBeFragmented(pb, networkMTU, gso) {
+ // Keep track of the packet that is about to be fragmented so it can be
+ // removed once the fragmentation is done.
+ originalPkt := pb
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // Modify the packet list in place with the new fragments.
+ pkts.InsertAfter(pb, fragPkt)
+ pb = fragPkt
+ return nil
+ }); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+ // Remove the packet that was just fragmented and process the rest.
+ pkts.Remove(originalPkt)
+ }
+ }
+
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
+ dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ if len(dropped) == 0 && len(natPkts) == 0 {
+ // Fast path: If no packets are to be dropped then we can just invoke the
+ // faster WritePackets API directly.
+ n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
+ return n, err
+ }
+ r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+
+ // Slow path as we are dropping some packets in the batch degrade to
+ // emitting one packet at a time.
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if _, ok := dropped[pkt]; ok {
+ continue
+ }
+ if _, ok := natPkts[pkt]; ok {
+ netHeader := header.IPv6(pkt.NetworkHeader().View())
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ src := netHeader.SourceAddress()
+ dst := netHeader.DestinationAddress()
+ route := r.ReverseRoute(src, dst)
+ ep.HandlePacket(&route, pkt)
+ n++
+ continue
+ }
+ }
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
+ // Dropped packets aren't errors, so include them in
+ // the return value.
+ return n + len(dropped), err
+ }
+ n++
}
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ // Dropped packets aren't errors, so include them in the return value.
+ return n + len(dropped), nil
}
-// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
-// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
- // TODO(b/146666412): Support IPv6 header-included packets.
- return tcpip.ErrNotSupported
+// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+ // The packet already has an IP header, but there are a few required checks.
+ h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return tcpip.ErrMalformedHeader
+ }
+ ip := header.IPv6(h)
+
+ // Always set the payload length.
+ pktSize := pkt.Data.Size()
+ ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize))
+
+ // Set the source address when zero.
+ if ip.SourceAddress() == header.IPv6Any {
+ ip.SetSourceAddress(r.LocalAddress)
+ }
+
+ // Set the destination. If the packet already included a destination, it will
+ // be part of the route anyways.
+ ip.SetDestinationAddress(r.RemoteAddress)
+
+ // Populate the packet buffer's network header and don't allow an invalid
+ // packet to be sent.
+ //
+ // Note that parsing only makes sure that the packet is well formed as per the
+ // wire format. We also want to check if the header's fields are valid before
+ // sending the packet.
+ proto, _, _, _, ok := parse.IPv6(pkt)
+ if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) {
+ return tcpip.ErrMalformedHeader
+ }
+
+ return e.writePacket(r, nil /* gso */, pkt, proto)
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
- if !ok {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ if !e.isEnabled() {
return
}
- h := header.IPv6(headerView)
- if !h.IsValid(pkt.Data.Size()) {
+
+ h := header.IPv6(pkt.NetworkHeader().View())
+ if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
- pkt.NetworkHeader = headerView[:header.IPv6MinimumSize]
- pkt.Data.TrimFront(header.IPv6MinimumSize)
- pkt.Data.CapLength(int(h.PayloadLength()))
+ // As per RFC 4291 section 2.7:
+ // Multicast addresses must not be used as source addresses in IPv6
+ // packets or appear in any Routing header.
+ if header.IsV6MulticastAddress(r.RemoteAddress) {
+ r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
- it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data)
+ // vv consists of:
+ // - Any IPv6 header bytes after the first 40 (i.e. extensions).
+ // - The transport header, if present.
+ // - Any other payload data.
+ vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView()
+ vv.AppendView(pkt.TransportHeader().View())
+ vv.Append(pkt.Data)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
hasFragmentHeader := false
- for firstHeader := true; ; firstHeader = false {
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and need not be forwarded.
+ ipt := e.protocol.stack.IPTables()
+ if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ // iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesInputDropped.Increment()
+ return
+ }
+
+ for {
+ // Keep track of the start of the previous header so we can report the
+ // special case of a Hop by Hop at a location other than at the start.
+ previousHeaderStart := it.HeaderOffset()
extHdr, done, err := it.Next()
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -203,11 +692,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
case header.IPv6HopByHopOptionsExtHdr:
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1
- // (unrecognized next header) error in response to an extension header's
- // Next Header field with the Hop By Hop extension header identifier.
- if !firstHeader {
+ if previousHeaderStart != 0 {
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: previousHeaderStart,
+ }, pkt)
return
}
@@ -229,13 +718,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
@@ -246,25 +747,27 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
// As per RFC 8200 section 4.4, if a node encounters a routing header with
// an unrecognized routing type value, with a non-zero Segments Left
// value, the node must discard the packet and send an ICMP Parameter
- // Problem, Code 0. If the Segments Left is 0, the node must ignore the
- // Routing extension header and process the next header in the packet.
+ // Problem, Code 0 to the packet's Source Address, pointing to the
+ // unrecognized Routing Type.
+ //
+ // If the Segments Left is 0, the node must ignore the Routing extension
+ // header and process the next header in the packet.
//
// Note, the stack does not yet handle any type of routing extension
// header, so we just make sure Segments Left is zero before processing
// the next extension header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for
- // unrecognized routing types with a non-zero Segments Left value.
if extHdr.SegmentsLeft() != 0 {
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
return
}
case header.IPv6FragmentExtHdr:
hasFragmentHeader = true
- fragmentOffset := extHdr.FragmentOffset()
- more := extHdr.More()
- if !more && fragmentOffset == 0 {
+ if extHdr.IsAtomic() {
// This fragment extension header indicates that this packet is an
// atomic fragment. An atomic fragment is a fragment that contains
// all the data required to reassemble a full packet. As per RFC 6946,
@@ -274,12 +777,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
continue
}
+ fragmentFieldOffset := it.ParseOffset()
+
// Don't consume the iterator if we have the first fragment because we
// will use it to validate that the first fragment holds the upper layer
// header.
- rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */)
+ rawPayload := it.AsRawHeader(extHdr.FragmentOffset() != 0 /* consume */)
- if fragmentOffset == 0 {
+ if extHdr.FragmentOffset() == 0 {
// Check that the iterator ends with a raw payload as the first fragment
// should include all headers up to and including any upper layer
// headers, as per RFC 8200 section 4.5; only upper layer data
@@ -290,7 +795,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
it, done, err := it.Next()
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
if done {
@@ -331,32 +836,89 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
return
}
+ // As per RFC 2460 Section 4.5:
+ //
+ // If the length of a fragment, as derived from the fragment packet's
+ // Payload Length field, is not a multiple of 8 octets and the M flag
+ // of that fragment is 1, then that fragment must be discarded and an
+ // ICMP Parameter Problem, Code 0, message should be sent to the source
+ // of the fragment, pointing to the Payload Length field of the
+ // fragment packet.
+ if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: header.IPv6PayloadLenOffset,
+ }, pkt)
+ return
+ }
+
// The packet is a fragment, let's try to reassemble it.
- start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
- last := start + uint16(fragmentPayloadLen) - 1
+ start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
- // Drop the packet if the fragmentOffset is incorrect. i.e the
- // combination of fragmentOffset and pkt.Data.size() causes a
- // wrap around resulting in last being less than the offset.
- if last < start {
+ // As per RFC 2460 Section 4.5:
+ //
+ // If the length and offset of a fragment are such that the Payload
+ // Length of the packet reassembled from that fragment would exceed
+ // 65,535 octets, then that fragment must be discarded and an ICMP
+ // Parameter Problem, Code 0, message should be sent to the source of
+ // the fragment, pointing to the Fragment Offset field of the fragment
+ // packet.
+ if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: fragmentFieldOffset,
+ }, pkt)
return
}
- var ready bool
- pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf)
+ // Set up a callback in case we need to send a Time Exceeded Message as
+ // per RFC 2460 Section 4.5.
+ var releaseCB func(bool)
+ if start == 0 {
+ pkt := pkt.Clone()
+ r := r.Clone()
+ releaseCB = func(timedOut bool) {
+ if timedOut {
+ _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt)
+ }
+ r.Release()
+ }
+ }
+
+ // Note that pkt doesn't have its transport header set after reassembly,
+ // and won't until DeliverNetworkPacket sets it.
+ data, proto, ready, err := e.protocol.fragmentation.Process(
+ // IPv6 ignores the Protocol field since the ID only needs to be unique
+ // across source-destination pairs, as per RFC 8200 section 4.5.
+ fragmentation.FragmentID{
+ Source: h.SourceAddress(),
+ Destination: h.DestinationAddress(),
+ ID: extHdr.ID(),
+ },
+ start,
+ start+uint16(fragmentPayloadLen)-1,
+ extHdr.More(),
+ uint8(rawPayload.Identifier),
+ rawPayload.Buf,
+ releaseCB,
+ )
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
+ pkt.Data = data
if ready {
// We create a new iterator with the reassembled packet because we could
// have more extension headers in the reassembled payload, as per RFC
- // 8200 section 4.5.
- it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data)
+ // 8200 section 4.5. We also use the NextHeader value from the first
+ // fragment.
+ it = header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(proto), pkt.Data)
}
case header.IPv6DestinationOptionsExtHdr:
@@ -378,13 +940,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
@@ -394,23 +968,64 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
case header.IPv6RawPayloadHeader:
// If the last header in the payload isn't a known IPv6 extension header,
// handle it as if it is transport layer data.
+
+ // For unfragmented packets, extHdr still contains the transport header.
+ // Get rid of it.
+ //
+ // For reassembled fragments, pkt.TransportHeader is unset, so this is a
+ // no-op and pkt.Data begins with the transport header.
+ extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
pkt.Data = extHdr.Buf
+ r.Stats().IP.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
- e.handleICMP(r, headerView, pkt, hasFragmentHeader)
+ pkt.TransportProtocolNumber = p
+ e.handleICMP(r, pkt, hasFragmentHeader)
} else {
r.Stats().IP.PacketsDelivered.Increment()
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
- e.dispatcher.DeliverTransportPacket(r, p, pkt)
+ switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ case stack.TransportPacketHandled:
+ case stack.TransportPacketDestinationPortUnreachable:
+ // As per RFC 4443 section 3.1:
+ // A destination node SHOULD originate a Destination Unreachable
+ // message with Code 4 in response to a packet for which the
+ // transport protocol (e.g., UDP) has no listener, if that transport
+ // protocol has no alternative means to inform the sender.
+ _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC 8200 section 4. (page 7):
+ // Extension headers are numbered from IANA IP Protocol Numbers
+ // [IANA-PN], the same values used for IPv4 and IPv6. When
+ // processing a sequence of Next Header values in a packet, the
+ // first one that is not an extension header [IANA-EH] indicates
+ // that the next item in the packet is the corresponding upper-layer
+ // header.
+ // With more related information on page 8:
+ // If, as a result of processing a header, the destination node is
+ // required to proceed to the next header but the Next Header value
+ // in the current header is unrecognized by the node, it should
+ // discard the packet and send an ICMP Parameter Problem message to
+ // the source of the packet, with an ICMP Code value of 1
+ // ("unrecognized Next Header type encountered") and the ICMP
+ // Pointer field containing the offset of the unrecognized value
+ // within the original packet.
+ //
+ // Which when taken together indicate that an unknown protocol should
+ // be treated as an unrecognized next header value.
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
+ default:
+ panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
+ }
}
default:
- // If we receive a packet for an extension header we do not yet handle,
- // drop the packet for now.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
r.Stats().UnknownProtocolRcvdPackets.Increment()
return
}
@@ -418,18 +1033,343 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
}
// Close cleans up resources associated with the endpoint.
-func (*endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.disableLocked()
+ e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */)
+ e.stopDADForPermanentAddressesLocked()
+ e.mu.addressableEndpointState.Cleanup()
+ e.mu.Unlock()
+
+ e.protocol.forgetEndpoint(e)
+}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
+// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ // TODO(b/169350103): add checks here after making sure we no longer receive
+ // an empty address.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
+}
+
+// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but
+// with locking requirements.
+//
+// addAndAcquirePermanentAddressLocked also joins the passed address's
+// solicited-node multicast group and start duplicate address detection.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ if err != nil {
+ return nil, err
+ }
+
+ if !header.IsV6UnicastAddress(addr.Address) {
+ return addressEndpoint, nil
+ }
+
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil {
+ return nil, err
+ }
+
+ addressEndpoint.SetKind(stack.PermanentTentative)
+
+ if e.Enabled() {
+ if err := e.mu.ndp.startDuplicateAddressDetection(addr.Address, addressEndpoint); err != nil {
+ return nil, err
+ }
+ }
+
+ return addressEndpoint, nil
+}
+
+// RemovePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ return e.removePermanentEndpointLocked(addressEndpoint, true)
+}
+
+// removePermanentEndpointLocked is like removePermanentAddressLocked except
+// it works with a stack.AddressEndpoint.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
+ addr := addressEndpoint.AddressWithPrefix()
+ unicast := header.IsV6UnicastAddress(addr.Address)
+ if unicast {
+ e.mu.ndp.stopDuplicateAddressDetection(addr.Address)
+
+ // If we are removing an address generated via SLAAC, cleanup
+ // its SLAAC resources and notify the integrator.
+ switch addressEndpoint.ConfigType() {
+ case stack.AddressConfigSlaac:
+ e.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ case stack.AddressConfigSlaacTemp:
+ e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ }
+ }
+
+ if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil {
+ return err
+ }
+
+ if !unicast {
+ return nil
+ }
+
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+
+ return nil
+}
+
+// hasPermanentAddressLocked returns true if the endpoint has a permanent
+// address equal to the passed address.
+//
+// Precondition: e.mu must be read or write locked.
+func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool {
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil {
+ return false
+ }
+ return addressEndpoint.GetKind().IsPermanent()
+}
+
+// getAddressRLocked returns the endpoint for the passed address.
+//
+// Precondition: e.mu must be read or write locked.
+func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint {
+ return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr)
+}
+
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
+// AcquireAssignedAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB)
+}
+
+// acquireAddressOrCreateTempLocked is like AcquireAssignedAddress but with
+// locking requirements.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) acquireAddressOrCreateTempLocked(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ return e.mu.addressableEndpointState.AcquireAssignedAddress(localAddr, allowTemp, tempPEB)
+}
+
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
+}
+
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements.
+//
+// Precondition: e.mu must be read locked.
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ // addrCandidate is a candidate for Source Address Selection, as per
+ // RFC 6724 section 5.
+ type addrCandidate struct {
+ addressEndpoint stack.AddressEndpoint
+ scope header.IPv6AddressScope
+ }
+
+ if len(remoteAddr) == 0 {
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
+ }
+
+ // Create a candidate set of available addresses we can potentially use as a
+ // source address.
+ var cs []addrCandidate
+ e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
+ // If r is not valid for outgoing connections, it is not a valid endpoint.
+ if !addressEndpoint.IsAssigned(allowExpired) {
+ return
+ }
+
+ addr := addressEndpoint.AddressWithPrefix().Address
+ scope, err := header.ScopeForIPv6Address(addr)
+ if err != nil {
+ // Should never happen as we got r from the primary IPv6 endpoint list and
+ // ScopeForIPv6Address only returns an error if addr is not an IPv6
+ // address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err))
+ }
+
+ cs = append(cs, addrCandidate{
+ addressEndpoint: addressEndpoint,
+ scope: scope,
+ })
+ })
+
+ remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
+ if err != nil {
+ // primaryIPv6Endpoint should never be called with an invalid IPv6 address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
+ }
+
+ // Sort the addresses as per RFC 6724 section 5 rules 1-3.
+ //
+ // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ sort.Slice(cs, func(i, j int) bool {
+ sa := cs[i]
+ sb := cs[j]
+
+ // Prefer same address as per RFC 6724 section 5 rule 1.
+ if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ return true
+ }
+ if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ return false
+ }
+
+ // Prefer appropriate scope as per RFC 6724 section 5 rule 2.
+ if sa.scope < sb.scope {
+ return sa.scope >= remoteScope
+ } else if sb.scope < sa.scope {
+ return sb.scope < remoteScope
+ }
+
+ // Avoid deprecated addresses as per RFC 6724 section 5 rule 3.
+ if saDep, sbDep := sa.addressEndpoint.Deprecated(), sb.addressEndpoint.Deprecated(); saDep != sbDep {
+ // If sa is not deprecated, it is preferred over sb.
+ return sbDep
+ }
+
+ // Prefer temporary addresses as per RFC 6724 section 5 rule 7.
+ if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp {
+ return saTemp
+ }
+
+ // sa and sb are equal, return the endpoint that is closest to the front of
+ // the primary endpoint list.
+ return i < j
+ })
+
+ // Return the most preferred address that can have its reference count
+ // incremented.
+ for _, c := range cs {
+ if c.addressEndpoint.IncRef() {
+ return c.addressEndpoint
+ }
+ }
+
+ return nil
+}
+
+// PrimaryAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PrimaryAddresses()
+}
+
+// PermanentAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PermanentAddresses()
+}
+
+// JoinGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ if !header.IsV6MulticastAddress(addr) {
+ return false, tcpip.ErrBadAddress
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.JoinGroup(addr)
+}
+
+// LeaveGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.LeaveGroup(addr)
+}
+
+// IsInGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.IsInGroup(addr)
+}
+
+var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
+var _ stack.NetworkProtocol = (*protocol)(nil)
+
type protocol struct {
+ stack *stack.Stack
+
+ mu struct {
+ sync.RWMutex
+
+ eps map[*endpoint]struct{}
+ }
+
+ ids []uint32
+ hashIV uint32
+
// defaultTTL is the current default TTL for the protocol. Only the
- // uint8 portion of it is meaningful and it must be accessed
- // atomically.
+ // uint8 portion of it is meaningful.
+ //
+ // Must be accessed using atomic operations.
defaultTTL uint32
+
+ // forwarding is set to 1 when the protocol has forwarding enabled and 0
+ // when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
+ fragmentation *fragmentation.Fragmentation
+
+ // ndpDisp is the NDP event dispatcher that is used to send the netstack
+ // integrator NDP related events.
+ ndpDisp NDPDispatcher
+
+ // ndpConfigs is the default NDP configurations used by an IPv6 endpoint.
+ ndpConfigs NDPConfigurations
+
+ // opaqueIIDOpts hold the options for generating opaque interface identifiers
+ // (IIDs) as outlined by RFC 7217.
+ opaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // tempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ tempIIDSeed []byte
+
+ // autoGenIPv6LinkLocal determines whether or not the stack attempts to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
+ autoGenIPv6LinkLocal bool
}
// Number returns the ipv6 protocol number.
@@ -454,24 +1394,42 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
- return &endpoint{
- nicID: nicID,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
- linkEP: linkEP,
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
+ nic: nic,
linkAddrCache: linkAddrCache,
+ nud: nud,
dispatcher: dispatcher,
- fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
protocol: p,
- }, nil
+ }
+ e.mu.addressableEndpointState.Init(e)
+ e.mu.ndp = ndpState{
+ ep: e,
+ configs: p.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
+ }
+ e.mu.ndp.initializeTempAddrState()
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.mu.eps[e] = struct{}{}
+ return e
+}
+
+func (p *protocol) forgetEndpoint(e *endpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ delete(p.mu.eps, e)
}
// SetOption implements NetworkProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case tcpip.DefaultTTLOption:
- p.SetDefaultTTL(uint8(v))
+ case *tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(*v))
return nil
default:
return tcpip.ErrUnknownProtocolOption
@@ -479,7 +1437,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
// Option implements NetworkProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(p.DefaultTTL())
@@ -505,17 +1463,193 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// calculateMTU calculates the network-layer payload MTU based on the link-layer
-// payload mtu.
-func calculateMTU(mtu uint32) uint32 {
- mtu -= header.IPv6MinimumSize
- if mtu <= maxPayloadSize {
- return mtu
+// Parse implements stack.NetworkProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt)
+ if !ok {
+ return 0, false, false
}
- return maxPayloadSize
+
+ return proto, !fragMore && fragOffset == 0, true
}
-// NewProtocol returns an IPv6 network protocol.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{defaultTTL: DefaultTTL}
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) Forwarding() bool {
+ return uint8(atomic.LoadUint32(&p.forwarding)) == 1
+}
+
+// setForwarding sets the forwarding status for the protocol.
+//
+// Returns true if the forwarding status was updated.
+func (p *protocol) setForwarding(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&p.forwarding, 1) == 0
+ }
+ return atomic.SwapUint32(&p.forwarding, 0) == 1
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) SetForwarding(v bool) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if !p.setForwarding(v) {
+ return
+ }
+
+ for ep := range p.mu.eps {
+ ep.transitionForwarding(v)
+ }
+}
+
+// calculateNetworkMTU calculates the network-layer payload MTU based on the
+// link-layer payload MTU and the length of every IPv6 header.
+// Note that this is different than the Payload Length field of the IPv6 header,
+// which includes the length of the extension headers.
+func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) {
+ if linkMTU < header.IPv6MinimumMTU {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ // As per RFC 7112 section 5, we should discard packets if their IPv6 header
+ // is bigger than 1280 bytes (ie, the minimum link MTU) since we do not
+ // support PMTU discovery:
+ // Hosts that do not discover the Path MTU MUST limit the IPv6 Header Chain
+ // length to 1280 bytes. Limiting the IPv6 Header Chain length to 1280
+ // bytes ensures that the header chain length does not exceed the IPv6
+ // minimum MTU.
+ if networkHeadersLen > header.IPv6MinimumMTU {
+ return 0, tcpip.ErrMalformedHeader
+ }
+
+ networkMTU := linkMTU - uint32(networkHeadersLen)
+ if networkMTU > maxPayloadSize {
+ networkMTU = maxPayloadSize
+ }
+ return networkMTU, nil
+}
+
+// Options holds options to configure a new protocol.
+type Options struct {
+ // NDPConfigs is the default NDP configurations used by interfaces.
+ NDPConfigs NDPConfigurations
+
+ // AutoGenIPv6LinkLocal determines whether or not the stack attempts to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs.
+ //
+ // Note, setting this to true does not mean that a link-local address is
+ // assigned right away, or at all. If Duplicate Address Detection is enabled,
+ // an address is only assigned if it successfully resolves. If it fails, no
+ // further attempts are made to auto-generate an IPv6 link-local adddress.
+ //
+ // The generated link-local address follows RFC 4291 Appendix A guidelines.
+ AutoGenIPv6LinkLocal bool
+
+ // NDPDisp is the NDP event dispatcher that an integrator can provide to
+ // receive NDP related events.
+ NDPDisp NDPDispatcher
+
+ // OpaqueIIDOpts hold the options for generating opaque interface
+ // identifiers (IIDs) as outlined by RFC 7217.
+ OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // TempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ //
+ // Temporary SLAAC adresses are short-lived addresses which are unpredictable
+ // and random from the perspective of other nodes on the network. It is
+ // recommended that the seed be a random byte buffer of at least
+ // header.IIDSize bytes to make sure that temporary SLAAC addresses are
+ // sufficiently random. It should follow minimum randomness requirements for
+ // security as outlined by RFC 4086.
+ //
+ // Note: using a nil value, the same seed across netstack program runs, or a
+ // seed that is too small would reduce randomness and increase predictability,
+ // defeating the purpose of temporary SLAAC addresses.
+ TempIIDSeed []byte
+}
+
+// NewProtocolWithOptions returns an IPv6 network protocol.
+func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
+ opts.NDPConfigs.validate()
+
+ ids := hash.RandN32(buckets)
+ hashIV := hash.RandN32(1)[0]
+
+ return func(s *stack.Stack) stack.NetworkProtocol {
+ p := &protocol{
+ stack: s,
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
+ ids: ids,
+ hashIV: hashIV,
+
+ ndpDisp: opts.NDPDisp,
+ ndpConfigs: opts.NDPConfigs,
+ opaqueIIDOpts: opts.OpaqueIIDOpts,
+ tempIIDSeed: opts.TempIIDSeed,
+ autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ }
+ p.mu.eps = make(map[*endpoint]struct{})
+ p.SetDefaultTTL(DefaultTTL)
+ return p
+ }
+}
+
+// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options.
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return NewProtocolWithOptions(Options{})(s)
+}
+
+func calculateFragmentReserve(pkt *stack.PacketBuffer) int {
+ return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize
+}
+
+// hashRoute calculates a hash value for the given route. It uses the source &
+// destination address and 32-bit number to generate the hash.
+func hashRoute(r *stack.Route, hashIV uint32) uint32 {
+ // The FNV-1a was chosen because it is a fast hashing algorithm, and
+ // cryptographic properties are not needed here.
+ h := fnv.New32a()
+ if _, err := h.Write([]byte(r.LocalAddress)); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err))
+ }
+
+ if _, err := h.Write([]byte(r.RemoteAddress)); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err))
+ }
+
+ s := make([]byte, 4)
+ binary.LittleEndian.PutUint32(s, hashIV)
+ if _, err := h.Write(s); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected ever to return an error", err))
+ }
+
+ return h.Sum32()
+}
+
+func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders header.IPv6, transportProto tcpip.TransportProtocolNumber, id uint32) (*stack.PacketBuffer, bool) {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
+
+ originalIPHeadersLength := len(originalIPHeaders)
+ fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
+ fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
+
+ // Copy the IPv6 header and any extension headers already populated.
+ if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
+ panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength))
+ }
+ fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader)
+ fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
+
+ fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:])
+ fragmentHeader.Encode(&header.IPv6FragmentFields{
+ M: more,
+ FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
+ Identification: id,
+ NextHeader: uint8(transportProto),
+ })
+
+ return fragPkt, more
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 841a0cb7a..c593c0004 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,15 +15,22 @@
package ipv6
import (
+ "encoding/hex"
+ "fmt"
+ "math"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -43,6 +50,8 @@ const (
fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
+
+ extraHeaderReserve = 50
)
// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
@@ -51,8 +60,8 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
t.Helper()
// Receive ICMP packet.
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
@@ -65,9 +74,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
stats := s.Stats().ICMP.V6PacketsReceived
@@ -123,9 +132,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
stat := s.Stats().UDP.PacketsReceived
@@ -134,25 +143,103 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
}
}
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
+ // sourcePacket does not have its IP Header populated. Let's copy the one
+ // from the first fragment.
+ source := header.IPv6(packets[0].NetworkHeader().View())
+ sourceIPHeadersLen := len(source)
+ vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
+ source = append(source, vv.ToView()...)
+
+ var reassembledPayload buffer.VectorisedView
+ for i, fragment := range packets {
+ // Confirm that the packet is valid.
+ allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views())
+ fragmentIPHeaders := header.IPv6(allBytes.ToView())
+ if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) {
+ return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders))
+ }
+
+ fragmentIPHeadersLength := fragment.NetworkHeader().View().Size()
+ if fragmentIPHeadersLength != sourceIPHeadersLen {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen)
+ }
+
+ if got := len(fragmentIPHeaders); got > int(mtu) {
+ return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu)
+ }
+
+ sourceIPHeader := source[:header.IPv6MinimumSize]
+ fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize]
+
+ if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize)
+ }
+
+ // We expect the IPv6 Header to be similar across each fragment, besides the
+ // payload length.
+ sourceIPHeader.SetPayloadLength(0)
+ fragmentIPHeader.SetPayloadLength(0)
+ if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" {
+ return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if got := fragment.AvailableHeaderBytes(); got != extraHeaderReserve {
+ return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve)
+ }
+ if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber {
+ return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber)
+ }
+
+ if len(packets) > 1 {
+ // If the source packet was big enough that it needed fragmentation, let's
+ // inspect the fragment header. Because no other extension headers are
+ // supported, it will always be the last extension header.
+ fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength])
+
+ if got := fragmentHeader.More(); got != wantFragments[i].more {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more)
+ }
+ if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset)
+ }
+ if got := fragmentHeader.NextHeader(); got != uint8(proto) {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto))
+ }
+ }
+
+ // Store the reassembled payload as we parse each fragment. The payload
+ // includes the Transport header and everything after.
+ reassembledPayload.AppendView(fragment.TransportHeader().View())
+ reassembledPayload.Append(fragment.Data)
+ }
+
+ if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" {
+ return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+
+ return nil
+}
+
// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
// UDP packets destined to the IPv6 link-local all-nodes multicast address.
func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
tests := []struct {
name string
- protocolFactory stack.TransportProtocol
+ protocolFactory stack.TransportProtocolFactory
rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
}{
- {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
- {"UDP", udp.NewProtocol(), testReceiveUDP},
+ {"ICMP", icmp.NewProtocol6, testReceiveICMP},
+ {"UDP", udp.NewProtocol, testReceiveUDP},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
})
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(10, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
@@ -168,15 +255,13 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
// packets destined to the IPv6 solicited-node address of an assigned IPv6
// address.
func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
- const nicID = 1
-
tests := []struct {
name string
- protocolFactory stack.TransportProtocol
+ protocolFactory stack.TransportProtocolFactory
rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
}{
- {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
- {"UDP", udp.NewProtocol(), testReceiveUDP},
+ {"ICMP", icmp.NewProtocol6, testReceiveICMP},
+ {"UDP", udp.NewProtocol, testReceiveUDP},
}
snmc := header.SolicitedNodeAddr(addr2)
@@ -184,16 +269,16 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
})
- e := channel.New(1, 1280, linkAddr1)
+ e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv6EmptySubnet,
NIC: nicID,
},
@@ -271,7 +356,7 @@ func TestAddIpv6Address(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -293,17 +378,22 @@ func TestAddIpv6Address(t *testing.T) {
}
func TestReceiveIPv6ExtHdrs(t *testing.T) {
- const nicID = 1
-
tests := []struct {
name string
extHdr func(nextHdr uint8) ([]byte, uint8)
shouldAccept bool
+ // Should we expect an ICMP response and if so, with what contents?
+ expectICMP bool
+ ICMPType header.ICMPv6Type
+ ICMPCode header.ICMPv6Code
+ pointer uint32
+ multicast bool
}{
{
name: "None",
extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
shouldAccept: true,
+ expectICMP: false,
},
{
name: "hopbyhop with unknown option skippable action",
@@ -334,9 +424,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action (unicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "hopbyhop with unknown option discard and send icmp action",
+ name: "hopbyhop with unknown option discard and send icmp action (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -346,12 +457,38 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
+ multicast: true,
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "hopbyhop with unknown option discard and send icmp action unless multicast dest",
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -362,39 +499,77 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
+ multicast: true,
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "routing with zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: true,
},
{
- name: "routing with non-zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with non-zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 1, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6ErroneousHeader,
+ pointer: header.IPv6FixedHeaderSize + 2,
},
{
- name: "atomic fragment with zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID },
+ name: "atomic fragment with zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 0, 0, 0, 0,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
},
{
- name: "atomic fragment with non-zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "atomic fragment with non-zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
+ expectICMP: false,
},
{
- name: "fragment",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{},
+ noNextHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
name: "destination with unknown option skippable action",
@@ -410,6 +585,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: true,
+ expectICMP: false,
},
{
name: "destination with unknown option discard action",
@@ -425,9 +601,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action (unicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "destination with unknown option discard and send icmp action",
+ name: "destination with unknown option discard and send icmp action (muilticast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -437,12 +634,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
}, destinationExtHdrID
},
+ multicast: true,
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "destination with unknown option discard and send icmp action unless multicast dest",
+ name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -453,22 +656,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "routing - atomic fragment",
+ name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ nextHdr, 1,
- // Fragment extension header.
- nextHdr, 0, 0, 0, 1, 2, 3, 4,
- }, routingExtHdrID
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
+ }, destinationExtHdrID
},
- shouldAccept: true,
+ shouldAccept: false,
+ expectICMP: false,
+ multicast: true,
},
{
name: "atomic fragment - routing",
@@ -502,12 +716,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
return []byte{
// Routing extension header.
hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The HopByHop extension header may not appear after the first
+ // extension header.
// Hop By Hop extension header with skippable unknown option.
nextHdr, 0, 62, 4, 1, 2, 3, 4,
}, routingExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
+ },
+ {
+ name: "routing - hop by hop (with send icmp unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The HopByHop extension header may not appear after the first
+ // extension header.
+
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, routingExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
},
{
name: "No next header",
@@ -551,6 +795,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
@@ -571,16 +816,17 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(0, 1280, linkAddr1)
+ e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -588,6 +834,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
+ // Add a default route so that a return packet knows where to go.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
@@ -629,17 +883,21 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Serialize IPv6 fixed header.
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ dstAddr := tcpip.Address(addr2)
+ if test.multicast {
+ dstAddr = header.IPv6AllNodesMulticastAddress
+ }
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: ipv6NextHdr,
HopLimit: 255,
SrcAddr: addr1,
- DstAddr: addr2,
+ DstAddr: dstAddr,
})
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
stats := s.Stats().UDP.PacketsReceived
@@ -648,6 +906,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = 0", got)
}
+ if !test.expectICMP {
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpected packet received: %#v", p)
+ }
+ return
+ }
+
+ // ICMP required.
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected packet wasn't written out")
+ }
+
+ // Pack the output packet into a single buffer.View as the checkers
+ // assume that.
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
+ if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want {
+ t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want)
+ }
+
+ ipHdr := header.IPv6(pkt)
+ checker.IPv6(t, ipHdr, checker.ICMPv6(
+ checker.ICMPv6Type(test.ICMPType),
+ checker.ICMPv6Code(test.ICMPCode)))
+
+ // We know we are looking at no extension headers in the error ICMP
+ // packets.
+ icmpPkt := header.ICMPv6(ipHdr.Payload())
+ // We know we sent small packets that won't be truncated when reflected
+ // back to us.
+ originalPacket := icmpPkt.Payload()
+ if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want {
+ t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want)
+ }
+ if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" {
+ t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff)
+ }
return
}
@@ -673,20 +969,27 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// fragmentData holds the IPv6 payload for a fragmented IPv6 packet.
type fragmentData struct {
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
nextHdr uint8
data buffer.VectorisedView
}
func TestReceiveIPv6Fragments(t *testing.T) {
- const nicID = 1
- const udpPayload1Length = 256
- const udpPayload2Length = 128
- const fragmentExtHdrLen = 8
- // Note, not all routing extension headers will be 8 bytes but this test
- // uses 8 byte routing extension headers for most sub tests.
- const routingExtHdrLen = 8
-
- udpGen := func(payload []byte, multiplier uint8) buffer.View {
+ const (
+ udpPayload1Length = 256
+ udpPayload2Length = 128
+ // Used to test cases where the fragment blocks are not a multiple of
+ // the fragment block size of 8 (RFC 8200 section 4.5).
+ udpPayload3Length = 127
+ udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
+ fragmentExtHdrLen = 8
+ // Note, not all routing extension headers will be 8 bytes but this test
+ // uses 8 byte routing extension headers for most sub tests.
+ routingExtHdrLen = 8
+ )
+
+ udpGen := func(payload []byte, multiplier uint8, src, dst tcpip.Address) buffer.View {
payloadLen := len(payload)
for i := 0; i < payloadLen; i++ {
payload[i] = uint8(i) * multiplier
@@ -702,19 +1005,31 @@ func TestReceiveIPv6Fragments(t *testing.T) {
Length: uint16(udpLength),
})
copy(u.Payload(), payload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength))
sum = header.Checksum(payload, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
return hdr.View()
}
- var udpPayload1Buf [udpPayload1Length]byte
- udpPayload1 := udpPayload1Buf[:]
- ipv6Payload1 := udpGen(udpPayload1, 1)
+ var udpPayload1Addr1ToAddr2Buf [udpPayload1Length]byte
+ udpPayload1Addr1ToAddr2 := udpPayload1Addr1ToAddr2Buf[:]
+ ipv6Payload1Addr1ToAddr2 := udpGen(udpPayload1Addr1ToAddr2, 1, addr1, addr2)
+
+ var udpPayload1Addr3ToAddr2Buf [udpPayload1Length]byte
+ udpPayload1Addr3ToAddr2 := udpPayload1Addr3ToAddr2Buf[:]
+ ipv6Payload1Addr3ToAddr2 := udpGen(udpPayload1Addr3ToAddr2, 4, addr3, addr2)
- var udpPayload2Buf [udpPayload2Length]byte
- udpPayload2 := udpPayload2Buf[:]
- ipv6Payload2 := udpGen(udpPayload2, 2)
+ var udpPayload2Addr1ToAddr2Buf [udpPayload2Length]byte
+ udpPayload2Addr1ToAddr2 := udpPayload2Addr1ToAddr2Buf[:]
+ ipv6Payload2Addr1ToAddr2 := udpGen(udpPayload2Addr1ToAddr2, 2, addr1, addr2)
+
+ var udpPayload3Addr1ToAddr2Buf [udpPayload3Length]byte
+ udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:]
+ ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2)
+
+ var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte
+ udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:]
+ ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2)
tests := []struct {
name string
@@ -726,34 +1041,60 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "No fragmentation",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: uint8(header.UDPProtocolNumber),
- data: ipv6Payload1.ToVectorisedView(),
+ data: ipv6Payload1Addr1ToAddr2.ToVectorisedView(),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
},
{
name: "Atomic fragment",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2),
+ []buffer.View{
+ // Fragment extension header.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+
+ ipv6Payload1Addr1ToAddr2,
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Atomic fragment with size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1),
+ fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
- ipv6Payload1,
+ ipv6Payload3Addr1ToAddr2,
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
},
{
name: "Two fragments",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+64,
@@ -763,31 +1104,189 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments out of order",
+ fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with different Next Header values",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ // NextHeader value is different than the one in the first fragment, so
+ // this NextHeader should be ignored.
+ buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with last fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with first fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+63,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[:63],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-63,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[63:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
},
{
name: "Two fragments with different IDs",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+64,
@@ -797,21 +1296,23 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 2
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
@@ -819,9 +1320,49 @@ func TestReceiveIPv6Fragments(t *testing.T) {
expectedPayloads: nil,
},
{
+ name: "Two fragments reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+65520,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[:65520],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8190, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[65520:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
+ },
+ {
name: "Two fragments with per-fragment routing header with zero segments left",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: routingExtHdrID,
data: buffer.NewVectorisedView(
routingExtHdrLen+fragmentExtHdrLen+64,
@@ -836,14 +1377,16 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: routingExtHdrID,
data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Routing extension header.
//
@@ -855,17 +1398,19 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 8, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
},
{
name: "Two fragments with per-fragment routing header with non-zero segments left",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: routingExtHdrID,
data: buffer.NewVectorisedView(
routingExtHdrLen+fragmentExtHdrLen+64,
@@ -880,14 +1425,16 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: routingExtHdrID,
data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Routing extension header.
//
@@ -899,7 +1446,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 9, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
@@ -910,6 +1457,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with routing header with zero segments left",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
routingExtHdrLen+fragmentExtHdrLen+64,
@@ -924,31 +1473,35 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Segments left = 0.
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 9, More = false, ID = 1
buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
},
{
name: "Two fragments with routing header with non-zero segments left",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
routingExtHdrLen+fragmentExtHdrLen+64,
@@ -963,21 +1516,23 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Segments left = 1.
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 9, More = false, ID = 1
buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
@@ -988,6 +1543,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with routing header with zero segments left across fragments",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
// The length of this payload is fragmentExtHdrLen+8 because the
@@ -1008,12 +1565,14 @@ func TestReceiveIPv6Fragments(t *testing.T) {
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
// The length of this payload is
- // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of
// the 16 byte routing extension header is in this fagment.
- fragmentExtHdrLen+8+len(ipv6Payload1),
+ fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
//
@@ -1023,7 +1582,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header (part 2)
buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
- ipv6Payload1,
+ ipv6Payload1Addr1ToAddr2,
},
),
},
@@ -1034,6 +1593,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with routing header with non-zero segments left across fragments",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
// The length of this payload is fragmentExtHdrLen+8 because the
@@ -1054,12 +1615,14 @@ func TestReceiveIPv6Fragments(t *testing.T) {
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
// The length of this payload is
- // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of
// the 16 byte routing extension header is in this fagment.
- fragmentExtHdrLen+8+len(ipv6Payload1),
+ fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
//
@@ -1069,7 +1632,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header (part 2)
buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
- ipv6Payload1,
+ ipv6Payload1Addr1ToAddr2,
},
),
},
@@ -1082,6 +1645,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with atomic",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+64,
@@ -1091,47 +1656,53 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
// This fragment has the same ID as the other fragments but is an atomic
// fragment. It should not interfere with the other fragments.
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload2),
+ fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 0, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}),
- ipv6Payload2,
+ ipv6Payload2Addr1ToAddr2,
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload2, udpPayload1},
+ expectedPayloads: [][]byte{udpPayload2Addr1ToAddr2, udpPayload1Addr1ToAddr2},
},
{
name: "Two interleaved fragmented packets",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+64,
@@ -1141,11 +1712,13 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+32,
@@ -1155,50 +1728,124 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 2
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}),
- ipv6Payload2[:32],
+ ipv6Payload2Addr1ToAddr2[:32],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload2)-32,
+ fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2)-32,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 4, More = false, ID = 2
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}),
- ipv6Payload2[32:],
+ ipv6Payload2Addr1ToAddr2[32:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2},
+ },
+ {
+ name: "Two interleaved fragmented packets from different sources but with same ID",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr3ToAddr2[:32],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 4, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr3ToAddr2[32:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1, udpPayload2},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(0, 1280, linkAddr1)
+ e := channel.New(0, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -1231,16 +1878,16 @@ func TestReceiveIPv6Fragments(t *testing.T) {
PayloadLength: uint16(f.data.Size()),
NextHeader: f.nextHdr,
HopLimit: 255,
- SrcAddr: addr1,
- DstAddr: addr2,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
})
vv := hdr.View().ToVectorisedView()
vv.Append(f.data)
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: vv,
- })
+ }))
}
if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
@@ -1263,3 +1910,920 @@ func TestReceiveIPv6Fragments(t *testing.T) {
})
}
}
+
+func TestInvalidIPv6Fragments(t *testing.T) {
+ const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ nicID = 1
+ hoplimit = 255
+ ident = 1
+ data = "TEST_INVALID_IPV6_FRAGMENTS"
+ )
+
+ type fragmentData struct {
+ ipv6Fields header.IPv6Fields
+ ipv6FragmentFields header.IPv6FragmentFields
+ payload []byte
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ wantMalformedIPPackets uint64
+ wantMalformedFragments uint64
+ expectICMP bool
+ expectICMPType header.ICMPv6Type
+ expectICMPCode header.ICMPv6Code
+ expectICMPTypeSpecific uint32
+ }{
+ {
+ name: "fragment size is not a multiple of 8 and the M flag is true",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 9,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0 >> 3,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:9],
+ },
+ },
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
+ expectICMP: true,
+ expectICMPType: header.ICMPv6ParamProblem,
+ expectICMPCode: header.ICMPv6ErroneousHeader,
+ expectICMPTypeSpecific: header.IPv6PayloadLenOffset,
+ },
+ {
+ name: "fragments reassembled into a payload exceeding the max IPv6 payload size",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
+ expectICMP: true,
+ expectICMPType: header.ICMPv6ParamProblem,
+ expectICMPCode: header.ICMPv6ErroneousHeader,
+ expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ NewProtocol,
+ },
+ })
+ e := channel.New(1, 1500, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ }})
+
+ var expectICMPPayload buffer.View
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
+ ip.Encode(&f.ipv6Fields)
+
+ fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
+ fragHDR.Encode(&f.ipv6FragmentFields)
+
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(f.payload)
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+
+ if test.expectICMP {
+ expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader())
+ }
+
+ e.InjectInbound(ProtocolNumber, pkt)
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
+ t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want)
+ }
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
+ t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want)
+ }
+
+ reply, ok := e.Read()
+ if !test.expectICMP {
+ if ok {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+ return
+ }
+ if !ok {
+ t.Fatal("expected ICMP error message missing")
+ }
+
+ checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(addr2),
+ checker.DstAddr(addr1),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())),
+ checker.ICMPv6(
+ checker.ICMPv6Type(test.expectICMPType),
+ checker.ICMPv6Code(test.expectICMPCode),
+ checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific),
+ checker.ICMPv6Payload([]byte(expectICMPPayload)),
+ ),
+ )
+ })
+ }
+}
+
+func TestFragmentReassemblyTimeout(t *testing.T) {
+ const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ nicID = 1
+ hoplimit = 255
+ ident = 1
+ data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT"
+ )
+
+ type fragmentData struct {
+ ipv6Fields header.IPv6Fields
+ ipv6FragmentFields header.IPv6FragmentFields
+ payload []byte
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ expectICMP bool
+ }{
+ {
+ name: "first fragment only",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two first fragments",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "second fragment only",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 8,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: false,
+ },
+ {
+ name: "two fragments with a gap",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 8,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two fragments with a gap in reverse order",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 8,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[16:],
+ },
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ NewProtocol,
+ },
+ Clock: clock,
+ })
+
+ e := channel.New(1, 1500, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ }})
+
+ var firstFragmentSent buffer.View
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
+ ip.Encode(&f.ipv6Fields)
+
+ fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
+ fragHDR.Encode(&f.ipv6FragmentFields)
+
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(f.payload)
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+
+ if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 {
+ firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader())
+ }
+
+ e.InjectInbound(ProtocolNumber, pkt)
+ }
+
+ clock.Advance(ReassembleTimeout)
+
+ reply, ok := e.Read()
+ if !test.expectICMP {
+ if ok {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+ return
+ }
+ if !ok {
+ t.Fatal("expected ICMP error message missing")
+ }
+ if firstFragmentSent == nil {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(addr2),
+ checker.DstAddr(addr1),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6TimeExceeded),
+ checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout),
+ checker.ICMPv6Payload([]byte(firstFragmentSent)),
+ ),
+ )
+ })
+ }
+}
+
+func TestWriteStats(t *testing.T) {
+ const nPackets = 3
+ tests := []struct {
+ name string
+ setup func(*testing.T, *stack.Stack)
+ allowPackets int
+ expectSent int
+ expectDropped int
+ expectWritten int
+ }{
+ {
+ name: "Accept all",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets,
+ expectDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Accept all with error",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: nPackets - 1,
+ expectSent: nPackets - 1,
+ expectDropped: 0,
+ expectWritten: nPackets - 1,
+ }, {
+ name: "Drop all",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule that matches only 1
+ // of the 3 packets.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectDropped: 1,
+ expectWritten: nPackets,
+ },
+ }
+
+ writers := []struct {
+ name string
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ }{
+ {
+ name: "WritePacket",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ nWritten := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
+ return nWritten, err
+ }
+ nWritten++
+ }
+ return nWritten, nil
+ },
+ }, {
+ name: "WritePackets",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
+ },
+ },
+ }
+
+ for _, writer := range writers {
+ t.Run(writer.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ rt := buildRoute(t, ep)
+ var pkts stack.PacketBufferList
+ for i := 0; i < nPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(header.UDPMinimumSize)
+ pkts.PushBack(pkt)
+ }
+
+ test.setup(t, rt.Stack())
+
+ nWritten, _ := writer.writePackets(&rt, pkts)
+
+ if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
+ t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
+ t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ }
+ if nWritten != test.expectWritten {
+ t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ }
+ })
+ }
+ })
+ }
+}
+
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC(1, _) failed: %s", err)
+ }
+ const (
+ src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ )
+ if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
+ }
+ {
+ mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
+ subnet, err := tcpip.NewSubnet(dst, mask)
+ if err != nil {
+ t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err)
+ }
+ return rt
+}
+
+// limitedMatcher is an iptables matcher that matches after a certain number of
+// packets are checked against it.
+type limitedMatcher struct {
+ limit int
+}
+
+// Name implements Matcher.Name.
+func (*limitedMatcher) Name() string {
+ return "limitedMatcher"
+}
+
+// Match implements Matcher.Match.
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+ if lm.limit == 0 {
+ return true, false
+ }
+ lm.limit--
+ return false, false
+}
+
+func TestClearEndpointFromProtocolOnClose(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint)
+ {
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[ep]
+ proto.mu.Unlock()
+ if !hasEP {
+ t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep)
+ }
+ }
+
+ ep.Close()
+
+ {
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[ep]
+ proto.mu.Unlock()
+ if hasEP {
+ t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep)
+ }
+ }
+}
+
+type fragmentInfo struct {
+ offset uint16
+ more bool
+ payloadSize uint16
+}
+
+var fragmentationTests = []struct {
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transHdrLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+}{
+ {
+ description: "No fragmentation",
+ mtu: header.IPv6MinimumMTU,
+ gso: nil,
+ transHdrLen: 0,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1000, more: false},
+ },
+ },
+ {
+ description: "Fragmented",
+ mtu: header.IPv6MinimumMTU,
+ gso: nil,
+ transHdrLen: 0,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 776, more: false},
+ },
+ },
+ {
+ description: "Fragmented with mtu not a multiple of 8",
+ mtu: header.IPv6MinimumMTU + 1,
+ gso: nil,
+ transHdrLen: 0,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 776, more: false},
+ },
+ },
+ {
+ description: "No fragmentation with big header",
+ mtu: 2000,
+ gso: nil,
+ transHdrLen: 100,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1100, more: false},
+ },
+ },
+ {
+ description: "Fragmented with gso none",
+ mtu: header.IPv6MinimumMTU,
+ gso: &stack.GSO{Type: stack.GSONone},
+ transHdrLen: 0,
+ payloadSize: 1400,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 176, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header",
+ mtu: header.IPv6MinimumMTU,
+ gso: nil,
+ transHdrLen: 100,
+ payloadSize: 1200,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 76, more: false},
+ },
+ },
+}
+
+func TestFragmentationWritePacket(t *testing.T) {
+ const (
+ ttl = 42
+ tos = stack.DefaultTOS
+ transportProto = tcp.ProtocolNumber
+ )
+
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ source := pkt.Clone()
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != nil {
+ t.Fatalf("WritePacket(_, _, _): = %s", err)
+ }
+ if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ })
+ }
+}
+
+func TestFragmentationWritePackets(t *testing.T) {
+ const ttl = 42
+ tests := []struct {
+ description string
+ insertBefore int
+ insertAfter int
+ }{
+ {
+ description: "Single packet",
+ insertBefore: 0,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet before",
+ insertBefore: 1,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet after",
+ insertBefore: 0,
+ insertAfter: 1,
+ },
+ {
+ description: "With packet before and after",
+ insertBefore: 1,
+ insertAfter: 1,
+ },
+ }
+ tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
+
+ for _, test := range tests {
+ t.Run(test.description, func(t *testing.T) {
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ var pkts stack.PacketBufferList
+ for i := 0; i < test.insertBefore; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ source := pkt
+ pkts.PushBack(pkt.Clone())
+ for i := 0; i < test.insertAfter; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+
+ wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
+ n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ })
+ if n != wantTotalPackets || err != nil {
+ t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets)
+ }
+ if got := len(ep.WrittenPackets); got != wantTotalPackets {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+
+ if wantTotalPackets == 0 {
+ return
+ }
+
+ fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
+ if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestFragmentationErrors checks that errors are returned from WritePacket
+// correctly.
+func TestFragmentationErrors(t *testing.T) {
+ const ttl = 42
+
+ tests := []struct {
+ description string
+ mtu uint32
+ transHdrLen int
+ payloadSize int
+ allowPackets int
+ outgoingErrors int
+ mockError *tcpip.Error
+ wantError *tcpip.Error
+ }{
+ {
+ description: "No frag",
+ mtu: 2000,
+ payloadSize: 1000,
+ transHdrLen: 0,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on first frag",
+ mtu: 1300,
+ payloadSize: 3000,
+ transHdrLen: 0,
+ allowPackets: 0,
+ outgoingErrors: 3,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on second frag",
+ mtu: 1500,
+ payloadSize: 4000,
+ transHdrLen: 0,
+ allowPackets: 1,
+ outgoingErrors: 2,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error when MTU is smaller than transport header",
+ mtu: header.IPv6MinimumMTU,
+ transHdrLen: 1500,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrMessageTooLong,
+ },
+ {
+ description: "Error when MTU is smaller than IPv6 minimum MTU",
+ mtu: header.IPv6MinimumMTU - 1,
+ transHdrLen: 0,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrInvalidEndpointState,
+ },
+ }
+
+ for _, ft := range tests {
+ t.Run(ft.description, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != ft.wantError {
+ t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
+ t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
+ }
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
new file mode 100644
index 000000000..40da011f8
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -0,0 +1,2013 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "fmt"
+ "log"
+ "math/rand"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending reachability probes.
+ //
+ // Default taken from RETRANS_TIMER of RFC 4861 section 10.
+ defaultRetransmitTimer = time.Second
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending reachability probes.
+ //
+ // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here
+ // to make sure the messages are not sent all at once. We also come to this
+ // value because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1. Note, the
+ // unit of the RetransmitTimer field in the Router Advertisement is
+ // milliseconds.
+ minimumRetransmitTimer = time.Millisecond
+
+ // defaultDupAddrDetectTransmits is the default number of NDP Neighbor
+ // Solicitation messages to send when doing Duplicate Address Detection
+ // for a tentative address.
+ //
+ // Default = 1 (from RFC 4862 section 5.1)
+ defaultDupAddrDetectTransmits = 1
+
+ // defaultMaxRtrSolicitations is the default number of Router
+ // Solicitation messages to send when an IPv6 endpoint becomes enabled.
+ //
+ // Default = 3 (from RFC 4861 section 10).
+ defaultMaxRtrSolicitations = 3
+
+ // defaultRtrSolicitationInterval is the default amount of time between
+ // sending Router Solicitation messages.
+ //
+ // Default = 4s (from 4861 section 10).
+ defaultRtrSolicitationInterval = 4 * time.Second
+
+ // defaultMaxRtrSolicitationDelay is the default maximum amount of time
+ // to wait before sending the first Router Solicitation message.
+ //
+ // Default = 1s (from 4861 section 10).
+ defaultMaxRtrSolicitationDelay = time.Second
+
+ // defaultHandleRAs is the default configuration for whether or not to
+ // handle incoming Router Advertisements as a host.
+ defaultHandleRAs = true
+
+ // defaultDiscoverDefaultRouters is the default configuration for
+ // whether or not to discover default routers from incoming Router
+ // Advertisements, as a host.
+ defaultDiscoverDefaultRouters = true
+
+ // defaultDiscoverOnLinkPrefixes is the default configuration for
+ // whether or not to discover on-link prefixes from incoming Router
+ // Advertisements' Prefix Information option, as a host.
+ defaultDiscoverOnLinkPrefixes = true
+
+ // defaultAutoGenGlobalAddresses is the default configuration for
+ // whether or not to generate global IPv6 addresses in response to
+ // receiving a new Prefix Information option with its Autonomous
+ // Address AutoConfiguration flag set, as a host.
+ //
+ // Default = true.
+ defaultAutoGenGlobalAddresses = true
+
+ // minimumRtrSolicitationInterval is the minimum amount of time to wait
+ // between sending Router Solicitation messages. This limit is imposed
+ // to make sure that Router Solicitation messages are not sent all at
+ // once, defeating the purpose of sending the initial few messages.
+ minimumRtrSolicitationInterval = 500 * time.Millisecond
+
+ // minimumMaxRtrSolicitationDelay is the minimum amount of time to wait
+ // before sending the first Router Solicitation message. It is 0 because
+ // we cannot have a negative delay.
+ minimumMaxRtrSolicitationDelay = 0
+
+ // MaxDiscoveredDefaultRouters is the maximum number of discovered
+ // default routers. The stack should stop discovering new routers after
+ // discovering MaxDiscoveredDefaultRouters routers.
+ //
+ // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
+ // SHOULD be more.
+ MaxDiscoveredDefaultRouters = 10
+
+ // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
+ // on-link prefixes. The stack should stop discovering new on-link
+ // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link
+ // prefixes.
+ MaxDiscoveredOnLinkPrefixes = 10
+
+ // validPrefixLenForAutoGen is the expected prefix length that an
+ // address can be generated for. Must be 64 bits as the interface
+ // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so
+ // 128 - 64 = 64.
+ validPrefixLenForAutoGen = 64
+
+ // defaultAutoGenTempGlobalAddresses is the default configuration for whether
+ // or not to generate temporary SLAAC addresses.
+ defaultAutoGenTempGlobalAddresses = true
+
+ // defaultMaxTempAddrValidLifetime is the default maximum valid lifetime
+ // for temporary SLAAC addresses generated as part of RFC 4941.
+ //
+ // Default = 7 days (from RFC 4941 section 5).
+ defaultMaxTempAddrValidLifetime = 7 * 24 * time.Hour
+
+ // defaultMaxTempAddrPreferredLifetime is the default preferred lifetime
+ // for temporary SLAAC addresses generated as part of RFC 4941.
+ //
+ // Default = 1 day (from RFC 4941 section 5).
+ defaultMaxTempAddrPreferredLifetime = 24 * time.Hour
+
+ // defaultRegenAdvanceDuration is the default duration before the deprecation
+ // of a temporary address when a new address will be generated.
+ //
+ // Default = 5s (from RFC 4941 section 5).
+ defaultRegenAdvanceDuration = 5 * time.Second
+
+ // minRegenAdvanceDuration is the minimum duration before the deprecation
+ // of a temporary address when a new address will be generated.
+ minRegenAdvanceDuration = time.Duration(0)
+
+ // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt
+ // SLAAC address regenerations in response to an IPv6 endpoint-local conflict.
+ maxSLAACAddrLocalRegenAttempts = 10
+)
+
+var (
+ // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid
+ // Lifetime to update the valid lifetime of a generated address by
+ // SLAAC.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // Min = 2hrs.
+ MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
+
+ // MaxDesyncFactor is the upper bound for the preferred lifetime's desync
+ // factor for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // Must be greater than 0.
+ //
+ // Max = 10m (from RFC 4941 section 5).
+ MaxDesyncFactor = 10 * time.Minute
+
+ // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the
+ // maximum preferred lifetime for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // This value guarantees that a temporary address is preferred for at
+ // least 1hr if the SLAAC prefix is valid for at least that time.
+ MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour
+
+ // MinMaxTempAddrValidLifetime is the minimum value allowed for the
+ // maximum valid lifetime for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // This value guarantees that a temporary address is valid for at least
+ // 2hrs if the SLAAC prefix is valid for at least that time.
+ MinMaxTempAddrValidLifetime = 2 * time.Hour
+)
+
+// NDPEndpoint is an endpoint that supports NDP.
+type NDPEndpoint interface {
+ // SetNDPConfigurations sets the NDP configurations.
+ SetNDPConfigurations(NDPConfigurations)
+}
+
+// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an
+// NDP Router Advertisement informed the Stack about.
+type DHCPv6ConfigurationFromNDPRA int
+
+const (
+ _ DHCPv6ConfigurationFromNDPRA = iota
+
+ // DHCPv6NoConfiguration indicates that no configurations are available via
+ // DHCPv6.
+ DHCPv6NoConfiguration
+
+ // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6.
+ //
+ // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6
+ // returns all available configuration information when serving addresses.
+ DHCPv6ManagedAddress
+
+ // DHCPv6OtherConfigurations indicates that other configuration information is
+ // available via DHCPv6.
+ //
+ // Other configurations are configurations other than addresses. Examples of
+ // other configurations are recursive DNS server list, DNS search lists and
+ // default gateway.
+ DHCPv6OtherConfigurations
+)
+
+// NDPDispatcher is the interface integrators of netstack must implement to
+// receive and handle NDP related events.
+type NDPDispatcher interface {
+ // OnDuplicateAddressDetectionStatus is called when the DAD process for an
+ // address (addr) on a NIC (with ID nicID) completes. resolved is set to true
+ // if DAD completed successfully (no duplicate addr detected); false otherwise
+ // (addr was detected to be a duplicate on the link the NIC is a part of, or
+ // it was stopped for some other reason, such as the address being removed).
+ // If an error occured during DAD, err is set and resolved must be ignored.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+
+ // OnDefaultRouterDiscovered is called when a new default router is
+ // discovered. Implementations must return true if the newly discovered
+ // router should be remembered.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool
+
+ // OnDefaultRouterInvalidated is called when a discovered default router that
+ // was remembered is invalidated.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address)
+
+ // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered.
+ // Implementations must return true if the newly discovered on-link prefix
+ // should be remembered.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool
+
+ // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that
+ // was remembered is invalidated.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet)
+
+ // OnAutoGenAddress is called when a new prefix with its autonomous address-
+ // configuration flag set is received and SLAAC was performed. Implementations
+ // may prevent the stack from assigning the address to the NIC by returning
+ // false.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+
+ // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC)
+ // is deprecated, but is still considered valid. Note, if an address is
+ // invalidated at the same ime it is deprecated, the deprecation event may not
+ // be received.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix)
+
+ // OnAutoGenAddressInvalidated is called when an auto-generated address
+ // (SLAAC) is invalidated.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix)
+
+ // OnRecursiveDNSServerOption is called when the stack learns of DNS servers
+ // through NDP. Note, the addresses may contain link-local addresses.
+ //
+ // It is up to the caller to use the DNS Servers only for their valid
+ // lifetime. OnRecursiveDNSServerOption may be called for new or
+ // already known DNS servers. If called with known DNS servers, their
+ // valid lifetimes must be refreshed to lifetime (it may be increased,
+ // decreased, or completely invalidated when lifetime = 0).
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
+
+ // OnDNSSearchListOption is called when the stack learns of DNS search lists
+ // through NDP.
+ //
+ // It is up to the caller to use the domain names in the search list
+ // for only their valid lifetime. OnDNSSearchListOption may be called
+ // with new or already known domain names. If called with known domain
+ // names, their valid lifetimes must be refreshed to lifetime (it may
+ // be increased, decreased or completely invalidated when lifetime = 0.
+ OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration)
+
+ // OnDHCPv6Configuration is called with an updated configuration that is
+ // available via DHCPv6 for the passed NIC.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA)
+}
+
+// NDPConfigurations is the NDP configurations for the netstack.
+type NDPConfigurations struct {
+ // The number of Neighbor Solicitation messages to send when doing
+ // Duplicate Address Detection for a tentative address.
+ //
+ // Note, a value of zero effectively disables DAD.
+ DupAddrDetectTransmits uint8
+
+ // The amount of time to wait between sending Neighbor solicitation
+ // messages.
+ //
+ // Must be greater than or equal to 1ms.
+ RetransmitTimer time.Duration
+
+ // The number of Router Solicitation messages to send when the IPv6 endpoint
+ // becomes enabled.
+ MaxRtrSolicitations uint8
+
+ // The amount of time between transmitting Router Solicitation messages.
+ //
+ // Must be greater than or equal to 0.5s.
+ RtrSolicitationInterval time.Duration
+
+ // The maximum amount of time before transmitting the first Router
+ // Solicitation message.
+ //
+ // Must be greater than or equal to 0s.
+ MaxRtrSolicitationDelay time.Duration
+
+ // HandleRAs determines whether or not Router Advertisements are processed.
+ HandleRAs bool
+
+ // DiscoverDefaultRouters determines whether or not default routers are
+ // discovered from Router Advertisements, as per RFC 4861 section 6. This
+ // configuration is ignored if HandleRAs is false.
+ DiscoverDefaultRouters bool
+
+ // DiscoverOnLinkPrefixes determines whether or not on-link prefixes are
+ // discovered from Router Advertisements' Prefix Information option, as per
+ // RFC 4861 section 6. This configuration is ignored if HandleRAs is false.
+ DiscoverOnLinkPrefixes bool
+
+ // AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs
+ // SLAAC to auto-generate global SLAAC addresses in response to Prefix
+ // Information options, as per RFC 4862.
+ //
+ // Note, if an address was already generated for some unique prefix, as
+ // part of SLAAC, this option does not affect whether or not the
+ // lifetime(s) of the generated address changes; this option only
+ // affects the generation of new addresses as part of SLAAC.
+ AutoGenGlobalAddresses bool
+
+ // AutoGenAddressConflictRetries determines how many times to attempt to retry
+ // generation of a permanent auto-generated address in response to DAD
+ // conflicts.
+ //
+ // If the method used to generate the address does not support creating
+ // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's
+ // MAC address), then no attempt is made to resolve the conflict.
+ AutoGenAddressConflictRetries uint8
+
+ // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC
+ // addresses are generated for an IPv6 endpoint as part of SLAAC privacy
+ // extensions, as per RFC 4941.
+ //
+ // Ignored if AutoGenGlobalAddresses is false.
+ AutoGenTempGlobalAddresses bool
+
+ // MaxTempAddrValidLifetime is the maximum valid lifetime for temporary
+ // SLAAC addresses.
+ MaxTempAddrValidLifetime time.Duration
+
+ // MaxTempAddrPreferredLifetime is the maximum preferred lifetime for
+ // temporary SLAAC addresses.
+ MaxTempAddrPreferredLifetime time.Duration
+
+ // RegenAdvanceDuration is the duration before the deprecation of a temporary
+ // address when a new address will be generated.
+ RegenAdvanceDuration time.Duration
+}
+
+// DefaultNDPConfigurations returns an NDPConfigurations populated with
+// default values.
+func DefaultNDPConfigurations() NDPConfigurations {
+ return NDPConfigurations{
+ DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
+ RetransmitTimer: defaultRetransmitTimer,
+ MaxRtrSolicitations: defaultMaxRtrSolicitations,
+ RtrSolicitationInterval: defaultRtrSolicitationInterval,
+ MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
+ HandleRAs: defaultHandleRAs,
+ DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
+ AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
+ AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses,
+ MaxTempAddrValidLifetime: defaultMaxTempAddrValidLifetime,
+ MaxTempAddrPreferredLifetime: defaultMaxTempAddrPreferredLifetime,
+ RegenAdvanceDuration: defaultRegenAdvanceDuration,
+ }
+}
+
+// validate modifies an NDPConfigurations with valid values. If invalid values
+// are present in c, the corresponding default values are used instead.
+func (c *NDPConfigurations) validate() {
+ if c.RetransmitTimer < minimumRetransmitTimer {
+ c.RetransmitTimer = defaultRetransmitTimer
+ }
+
+ if c.RtrSolicitationInterval < minimumRtrSolicitationInterval {
+ c.RtrSolicitationInterval = defaultRtrSolicitationInterval
+ }
+
+ if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay {
+ c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay
+ }
+
+ if c.MaxTempAddrValidLifetime < MinMaxTempAddrValidLifetime {
+ c.MaxTempAddrValidLifetime = MinMaxTempAddrValidLifetime
+ }
+
+ if c.MaxTempAddrPreferredLifetime < MinMaxTempAddrPreferredLifetime || c.MaxTempAddrPreferredLifetime > c.MaxTempAddrValidLifetime {
+ c.MaxTempAddrPreferredLifetime = MinMaxTempAddrPreferredLifetime
+ }
+
+ if c.RegenAdvanceDuration < minRegenAdvanceDuration {
+ c.RegenAdvanceDuration = minRegenAdvanceDuration
+ }
+}
+
+// ndpState is the per-interface NDP state.
+type ndpState struct {
+ // The IPv6 endpoint this ndpState is for.
+ ep *endpoint
+
+ // configs is the per-interface NDP configurations.
+ configs NDPConfigurations
+
+ // The DAD state to send the next NS message, or resolve the address.
+ dad map[tcpip.Address]dadState
+
+ // The default routers discovered through Router Advertisements.
+ defaultRouters map[tcpip.Address]defaultRouterState
+
+ rtrSolicit struct {
+ // The timer used to send the next router solicitation message.
+ timer tcpip.Timer
+
+ // Used to let the Router Solicitation timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the IPv6 endpoint this ndpState is associated with. MUST be set when the
+ // timer is set.
+ done *bool
+ }
+
+ // The on-link prefixes discovered through Router Advertisements' Prefix
+ // Information option.
+ onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState
+
+ // The SLAAC prefixes discovered through Router Advertisements' Prefix
+ // Information option.
+ slaacPrefixes map[tcpip.Subnet]slaacPrefixState
+
+ // The last learned DHCPv6 configuration from an NDP RA.
+ dhcpv6Configuration DHCPv6ConfigurationFromNDPRA
+
+ // temporaryIIDHistory is the history value used to generate a new temporary
+ // IID.
+ temporaryIIDHistory [header.IIDSize]byte
+
+ // temporaryAddressDesyncFactor is the preferred lifetime's desync factor for
+ // temporary SLAAC addresses.
+ temporaryAddressDesyncFactor time.Duration
+}
+
+// dadState holds the Duplicate Address Detection timer and channel to signal
+// to the DAD goroutine that DAD should stop.
+type dadState struct {
+ // The DAD timer to send the next NS message, or resolve the address.
+ timer tcpip.Timer
+
+ // Used to let the DAD timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the IPv6 endpoint this dadState is associated with.
+ done *bool
+}
+
+// defaultRouterState holds data associated with a default router discovered by
+// a Router Advertisement (RA).
+type defaultRouterState struct {
+ // Job to invalidate the default router.
+ //
+ // Must not be nil.
+ invalidationJob *tcpip.Job
+}
+
+// onLinkPrefixState holds data associated with an on-link prefix discovered by
+// a Router Advertisement's Prefix Information option (PI) when the NDP
+// configurations was configured to do so.
+type onLinkPrefixState struct {
+ // Job to invalidate the on-link prefix.
+ //
+ // Must not be nil.
+ invalidationJob *tcpip.Job
+}
+
+// tempSLAACAddrState holds state associated with a temporary SLAAC address.
+type tempSLAACAddrState struct {
+ // Job to deprecate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ deprecationJob *tcpip.Job
+
+ // Job to invalidate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ invalidationJob *tcpip.Job
+
+ // Job to regenerate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ regenJob *tcpip.Job
+
+ createdAt time.Time
+
+ // The address's endpoint.
+ //
+ // Must not be nil.
+ addressEndpoint stack.AddressEndpoint
+
+ // Has a new temporary SLAAC address already been regenerated?
+ regenerated bool
+}
+
+// slaacPrefixState holds state associated with a SLAAC prefix.
+type slaacPrefixState struct {
+ // Job to deprecate the prefix.
+ //
+ // Must not be nil.
+ deprecationJob *tcpip.Job
+
+ // Job to invalidate the prefix.
+ //
+ // Must not be nil.
+ invalidationJob *tcpip.Job
+
+ // Nonzero only when the address is not valid forever.
+ validUntil time.Time
+
+ // Nonzero only when the address is not preferred forever.
+ preferredUntil time.Time
+
+ // State associated with the stable address generated for the prefix.
+ stableAddr struct {
+ // The address's endpoint.
+ //
+ // May only be nil when the address is being (re-)generated. Otherwise,
+ // must not be nil as all SLAAC prefixes must have a stable address.
+ addressEndpoint stack.AddressEndpoint
+
+ // The number of times an address has been generated locally where the IPv6
+ // endpoint already had the generated address.
+ localGenerationFailures uint8
+ }
+
+ // The temporary (short-lived) addresses generated for the SLAAC prefix.
+ tempAddrs map[tcpip.Address]tempSLAACAddrState
+
+ // The next two fields are used by both stable and temporary addresses
+ // generated for a SLAAC prefix. This is safe as only 1 address is in the
+ // generation and DAD process at any time. That is, no two addresses are
+ // generated at the same time for a given SLAAC prefix.
+
+ // The number of times an address has been generated and added to the IPv6
+ // endpoint.
+ //
+ // Addresses may be regenerated in reseponse to a DAD conflicts.
+ generationAttempts uint8
+
+ // The maximum number of times to attempt regeneration of a SLAAC address
+ // in response to DAD conflicts.
+ maxGenerationAttempts uint8
+}
+
+// startDuplicateAddressDetection performs Duplicate Address Detection.
+//
+// This function must only be called by IPv6 addresses that are currently
+// tentative.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
+ // addr must be a valid unicast IPv6 address.
+ if !header.IsV6UnicastAddress(addr) {
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ // The endpoint should be marked as tentative since we are starting DAD.
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
+
+ // Should not attempt to perform DAD on an address that is currently in the
+ // DAD process.
+ if _, ok := ndp.dad[addr]; ok {
+ // Should never happen because we should only ever call this function for
+ // newly created addresses. If we attemped to "add" an address that already
+ // existed, we would get an error since we attempted to add a duplicate
+ // address, or its reference count would have been increased without doing
+ // the work that would have been done for an address that was brand new.
+ // See endpoint.addAddressLocked.
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
+
+ remaining := ndp.configs.DupAddrDetectTransmits
+ if remaining == 0 {
+ addressEndpoint.SetKind(stack.Permanent)
+
+ // Consider DAD to have resolved even if no DAD messages were actually
+ // transmitted.
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
+ }
+
+ return nil
+ }
+
+ var done bool
+ var timer tcpip.Timer
+ // We initially start a timer to fire immediately because some of the DAD work
+ // cannot be done while holding the IPv6 endpoint's lock. This is effectively
+ // the same as starting a goroutine but we use a timer that fires immediately
+ // so we can reset it for the next DAD iteration.
+ timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
+ ndp.ep.mu.Lock()
+ defer ndp.ep.mu.Unlock()
+
+ if done {
+ // If we reach this point, it means that the DAD timer fired after
+ // another goroutine already obtained the IPv6 endpoint lock and stopped
+ // DAD before this function obtained the NIC lock. Simply return here and
+ // do nothing further.
+ return
+ }
+
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ // The endpoint should still be marked as tentative since we are still
+ // performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
+
+ dadDone := remaining == 0
+
+ var err *tcpip.Error
+ if !dadDone {
+ // Use the unspecified address as the source address when performing DAD.
+ addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
+
+ // Do not hold the lock when sending packets which may be a long running
+ // task or may block link address resolution. We know this is safe
+ // because immediately after obtaining the lock again, we check if DAD
+ // has been stopped before doing any work with the IPv6 endpoint. Note,
+ // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if
+ // the address was removed.
+ ndp.ep.mu.Unlock()
+ err = ndp.sendDADPacket(addr, addressEndpoint)
+ ndp.ep.mu.Lock()
+ addressEndpoint.DecRef()
+ }
+
+ if done {
+ // If we reach this point, it means that DAD was stopped after we released
+ // the IPv6 endpoint's read lock and before we obtained the write lock.
+ return
+ }
+
+ if dadDone {
+ // DAD has resolved.
+ addressEndpoint.SetKind(stack.Permanent)
+ } else if err == nil {
+ // DAD is not done and we had no errors when sending the last NDP NS,
+ // schedule the next DAD timer.
+ remaining--
+ timer.Reset(ndp.configs.RetransmitTimer)
+ return
+ }
+
+ // At this point we know that either DAD is done or we hit an error sending
+ // the last NDP NS. Either way, clean up addr's DAD state and let the
+ // integrator know DAD has completed.
+ delete(ndp.dad, addr)
+
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
+ }
+
+ // If DAD resolved for a stable SLAAC address, attempt generation of a
+ // temporary SLAAC address.
+ if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
+ }
+ })
+
+ ndp.dad[addr] = dadState{
+ timer: timer,
+ done: &done,
+ }
+
+ return nil
+}
+
+// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns
+// addr.
+//
+// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
+//
+// The IPv6 endpoint that ndp belongs to MUST NOT be locked.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
+ snmc := header.SolicitedNodeAddr(addr)
+
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ // Route should resolve immediately since snmc is a multicast address so a
+ // remote link address can be calculated without a resolution process.
+ if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the IPv6 endpoint is not
+ // locked, the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
+ return err
+ }
+
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
+ } else if c != nil {
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
+ }
+
+ icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
+ icmpData.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmpData.NDPPayload())
+ ns.SetTargetAddress(addr)
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(icmpData).ToVectorisedView(),
+ })
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil,
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, pkt,
+ ); err != nil {
+ sent.Dropped.Increment()
+ return err
+ }
+ sent.NeighborSolicit.Increment()
+
+ return nil
+}
+
+// stopDuplicateAddressDetection ends a running Duplicate Address Detection
+// process. Note, this may leave the DAD process for a tentative address in
+// such a state forever, unless some other external event resolves the DAD
+// process (receiving an NA from the true owner of addr, or an NS for addr
+// (implying another node is attempting to use addr)). It is up to the caller
+// of this function to handle such a scenario.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
+ dad, ok := ndp.dad[addr]
+ if !ok {
+ // Not currently performing DAD on addr, just return.
+ return
+ }
+
+ if dad.timer != nil {
+ dad.timer.Stop()
+ dad.timer = nil
+
+ *dad.done = true
+ dad.done = nil
+ }
+
+ delete(ndp.dad, addr)
+
+ // Let the integrator know DAD did not resolve.
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
+ }
+}
+
+// handleRA handles a Router Advertisement message that arrived on the NIC
+// this ndp is for. Does nothing if the NIC is configured to not handle RAs.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
+ // Is the IPv6 endpoint configured to handle RAs at all?
+ //
+ // Currently, the stack does not determine router interface status on a
+ // per-interface basis; it is a protocol-wide configuration, so we check the
+ // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding
+ // packets.
+ if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() {
+ return
+ }
+
+ // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
+ // only inform the dispatcher on configuration changes. We do nothing else
+ // with the information.
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ var configuration DHCPv6ConfigurationFromNDPRA
+ switch {
+ case ra.ManagedAddrConfFlag():
+ configuration = DHCPv6ManagedAddress
+
+ case ra.OtherConfFlag():
+ configuration = DHCPv6OtherConfigurations
+
+ default:
+ configuration = DHCPv6NoConfiguration
+ }
+
+ if ndp.dhcpv6Configuration != configuration {
+ ndp.dhcpv6Configuration = configuration
+ ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration)
+ }
+ }
+
+ // Is the IPv6 endpoint configured to discover default routers?
+ if ndp.configs.DiscoverDefaultRouters {
+ rtr, ok := ndp.defaultRouters[ip]
+ rl := ra.RouterLifetime()
+ switch {
+ case !ok && rl != 0:
+ // This is a new default router we are discovering.
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredDefaultRouters routers.
+ if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters {
+ ndp.rememberDefaultRouter(ip, rl)
+ }
+
+ case ok && rl != 0:
+ // This is an already discovered default router. Update
+ // the invalidation job.
+ rtr.invalidationJob.Cancel()
+ rtr.invalidationJob.Schedule(rl)
+ ndp.defaultRouters[ip] = rtr
+
+ case ok && rl == 0:
+ // We know about the router but it is no longer to be
+ // used as a default router so invalidate it.
+ ndp.invalidateDefaultRouter(ip)
+ }
+ }
+
+ // TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter
+ // Discovery.
+
+ // We know the options is valid as far as wire format is concerned since
+ // we got the Router Advertisement, as documented by this fn. Given this
+ // we do not check the iterator for errors on calls to Next.
+ it, _ := ra.Options().Iter(false)
+ for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
+ switch opt := opt.(type) {
+ case header.NDPRecursiveDNSServer:
+ if ndp.ep.protocol.ndpDisp == nil {
+ continue
+ }
+
+ addrs, _ := opt.Addresses()
+ ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
+
+ case header.NDPDNSSearchList:
+ if ndp.ep.protocol.ndpDisp == nil {
+ continue
+ }
+
+ domainNames, _ := opt.DomainNames()
+ ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
+
+ case header.NDPPrefixInformation:
+ prefix := opt.Subnet()
+
+ // Is the prefix a link-local?
+ if header.IsV6LinkLocalAddress(prefix.ID()) {
+ // ...Yes, skip as per RFC 4861 section 6.3.4,
+ // and RFC 4862 section 5.5.3.b (for SLAAC).
+ continue
+ }
+
+ // Is the Prefix Length 0?
+ if prefix.Prefix() == 0 {
+ // ...Yes, skip as this is an invalid prefix
+ // as all IPv6 addresses cannot be on-link.
+ continue
+ }
+
+ if opt.OnLinkFlag() {
+ ndp.handleOnLinkPrefixInformation(opt)
+ }
+
+ if opt.AutonomousAddressConfigurationFlag() {
+ ndp.handleAutonomousPrefixInformation(opt)
+ }
+ }
+
+ // TODO(b/141556115): Do (MTU) Parameter Discovery.
+ }
+}
+
+// invalidateDefaultRouter invalidates a discovered default router.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
+ rtr, ok := ndp.defaultRouters[ip]
+
+ // Is the router still discovered?
+ if !ok {
+ // ...Nope, do nothing further.
+ return
+ }
+
+ rtr.invalidationJob.Cancel()
+ delete(ndp.defaultRouters, ip)
+
+ // Let the integrator know a discovered default router is invalidated.
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
+ }
+}
+
+// rememberDefaultRouter remembers a newly discovered default router with IPv6
+// link-local address ip with lifetime rl.
+//
+// The router identified by ip MUST NOT already be known by the IPv6 endpoint.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
+ ndpDisp := ndp.ep.protocol.ndpDisp
+ if ndpDisp == nil {
+ return
+ }
+
+ // Inform the integrator when we discovered a default router.
+ if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) {
+ // Informed by the integrator to not remember the router, do
+ // nothing further.
+ return
+ }
+
+ state := defaultRouterState{
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ ndp.invalidateDefaultRouter(ip)
+ }),
+ }
+
+ state.invalidationJob.Schedule(rl)
+
+ ndp.defaultRouters[ip] = state
+}
+
+// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6
+// address with prefix prefix with lifetime l.
+//
+// The prefix identified by prefix MUST NOT already be known.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
+ ndpDisp := ndp.ep.protocol.ndpDisp
+ if ndpDisp == nil {
+ return
+ }
+
+ // Inform the integrator when we discovered an on-link prefix.
+ if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) {
+ // Informed by the integrator to not remember the prefix, do
+ // nothing further.
+ return
+ }
+
+ state := onLinkPrefixState{
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ ndp.invalidateOnLinkPrefix(prefix)
+ }),
+ }
+
+ if l < header.NDPInfiniteLifetime {
+ state.invalidationJob.Schedule(l)
+ }
+
+ ndp.onLinkPrefixes[prefix] = state
+}
+
+// invalidateOnLinkPrefix invalidates a discovered on-link prefix.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
+ s, ok := ndp.onLinkPrefixes[prefix]
+
+ // Is the on-link prefix still discovered?
+ if !ok {
+ // ...Nope, do nothing further.
+ return
+ }
+
+ s.invalidationJob.Cancel()
+ delete(ndp.onLinkPrefixes, prefix)
+
+ // Let the integrator know a discovered on-link prefix is invalidated.
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix)
+ }
+}
+
+// handleOnLinkPrefixInformation handles a Prefix Information option with
+// its on-link flag set, as per RFC 4861 section 6.3.4.
+//
+// handleOnLinkPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the on-link flag is set.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) {
+ prefix := pi.Subnet()
+ prefixState, ok := ndp.onLinkPrefixes[prefix]
+ vl := pi.ValidLifetime()
+
+ if !ok && vl == 0 {
+ // Don't know about this prefix but it has a zero valid
+ // lifetime, so just ignore.
+ return
+ }
+
+ if !ok && vl != 0 {
+ // This is a new on-link prefix we are discovering
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredOnLinkPrefixes on-link prefixes.
+ if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes {
+ ndp.rememberOnLinkPrefix(prefix, vl)
+ }
+ return
+ }
+
+ if ok && vl == 0 {
+ // We know about the on-link prefix, but it is
+ // no longer to be considered on-link, so
+ // invalidate it.
+ ndp.invalidateOnLinkPrefix(prefix)
+ return
+ }
+
+ // This is an already discovered on-link prefix with a
+ // new non-zero valid lifetime.
+ //
+ // Update the invalidation job.
+
+ prefixState.invalidationJob.Cancel()
+
+ if vl < header.NDPInfiniteLifetime {
+ // Prefix is valid for a finite lifetime, schedule the job to execute after
+ // the new valid lifetime.
+ prefixState.invalidationJob.Schedule(vl)
+ }
+
+ ndp.onLinkPrefixes[prefix] = prefixState
+}
+
+// handleAutonomousPrefixInformation handles a Prefix Information option with
+// its autonomous flag set, as per RFC 4862 section 5.5.3.
+//
+// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the autonomous flag is set.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
+ vl := pi.ValidLifetime()
+ pl := pi.PreferredLifetime()
+
+ // If the preferred lifetime is greater than the valid lifetime,
+ // silently ignore the Prefix Information option, as per RFC 4862
+ // section 5.5.3.c.
+ if pl > vl {
+ return
+ }
+
+ prefix := pi.Subnet()
+
+ // Check if we already maintain SLAAC state for prefix.
+ if state, ok := ndp.slaacPrefixes[prefix]; ok {
+ // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes.
+ ndp.refreshSLAACPrefixLifetimes(prefix, &state, pl, vl)
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section
+ // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC.
+ if !ndp.configs.AutoGenGlobalAddresses {
+ return
+ }
+
+ ndp.doSLAAC(prefix, pl, vl)
+}
+
+// doSLAAC generates a new SLAAC address with the provided lifetimes
+// for prefix.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
+ // If we do not already have an address for this prefix and the valid
+ // lifetime is 0, no need to do anything further, as per RFC 4862
+ // section 5.5.3.d.
+ if vl == 0 {
+ return
+ }
+
+ // Make sure the prefix is valid (as far as its length is concerned) to
+ // generate a valid IPv6 address from an interface identifier (IID), as
+ // per RFC 4862 sectiion 5.5.3.d.
+ if prefix.Prefix() != validPrefixLenForAutoGen {
+ return
+ }
+
+ state := slaacPrefixState{
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
+ }
+
+ ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint)
+ }),
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
+ }
+
+ ndp.invalidateSLAACPrefix(prefix, state)
+ }),
+ tempAddrs: make(map[tcpip.Address]tempSLAACAddrState),
+ maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1,
+ }
+
+ now := time.Now()
+
+ // The time an address is preferred until is needed to properly generate the
+ // address.
+ if pl < header.NDPInfiniteLifetime {
+ state.preferredUntil = now.Add(pl)
+ }
+
+ if !ndp.generateSLAACAddr(prefix, &state) {
+ // We were unable to generate an address for the prefix, we do not nothing
+ // further as there is no reason to maintain state or jobs for a prefix we
+ // do not have an address for.
+ return
+ }
+
+ // Setup the initial jobs to deprecate and invalidate prefix.
+
+ if pl < header.NDPInfiniteLifetime && pl != 0 {
+ state.deprecationJob.Schedule(pl)
+ }
+
+ if vl < header.NDPInfiniteLifetime {
+ state.invalidationJob.Schedule(vl)
+ state.validUntil = now.Add(vl)
+ }
+
+ // If the address is assigned (DAD resolved), generate a temporary address.
+ if state.stableAddr.addressEndpoint.GetKind() == stack.Permanent {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */)
+ }
+
+ ndp.slaacPrefixes[prefix] = state
+}
+
+// addAndAcquireSLAACAddr adds a SLAAC address to the IPv6 endpoint.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint {
+ // Inform the integrator that we have a new SLAAC address.
+ ndpDisp := ndp.ep.protocol.ndpDisp
+ if ndpDisp == nil {
+ return nil
+ }
+
+ if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) {
+ // Informed by the integrator not to add the address.
+ return nil
+ }
+
+ addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
+ if err != nil {
+ panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
+ }
+
+ return addressEndpoint
+}
+
+// generateSLAACAddr generates a SLAAC address for prefix.
+//
+// Returns true if an address was successfully generated.
+//
+// Panics if the prefix is not a SLAAC prefix or it already has an address.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool {
+ if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil {
+ panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, addressEndpoint.AddressWithPrefix()))
+ }
+
+ // If we have already reached the maximum address generation attempts for the
+ // prefix, do not generate another address.
+ if state.generationAttempts == state.maxGenerationAttempts {
+ return false
+ }
+
+ var generatedAddr tcpip.AddressWithPrefix
+ addrBytes := []byte(prefix.ID())
+
+ for i := 0; ; i++ {
+ // If we were unable to generate an address after the maximum SLAAC address
+ // local regeneration attempts, do nothing further.
+ if i == maxSLAACAddrLocalRegenAttempts {
+ return false
+ }
+
+ dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures
+ if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ addrBytes = header.AppendOpaqueInterfaceIdentifier(
+ addrBytes[:header.IIDOffsetInIPv6Address],
+ prefix,
+ oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()),
+ dadCounter,
+ oIID.SecretKey,
+ )
+ } else if dadCounter == 0 {
+ // Modified-EUI64 based IIDs have no way to resolve DAD conflicts, so if
+ // the DAD counter is non-zero, we cannot use this method.
+ //
+ // Only attempt to generate an interface-specific IID if we have a valid
+ // link address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ linkAddr := ndp.ep.nic.LinkAddress()
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return false
+ }
+
+ // Generate an address within prefix from the modified EUI-64 of ndp's
+ // NIC's Ethernet MAC address.
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ } else {
+ // We have no way to regenerate an address in response to an address
+ // conflict when addresses are not generated with opaque IIDs.
+ return false
+ }
+
+ generatedAddr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: validPrefixLenForAutoGen,
+ }
+
+ if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) {
+ break
+ }
+
+ state.stableAddr.localGenerationFailures++
+ }
+
+ if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil {
+ state.stableAddr.addressEndpoint = addressEndpoint
+ state.generationAttempts++
+ return true
+ }
+
+ return false
+}
+
+// regenerateSLAACAddr regenerates an address for a SLAAC prefix.
+//
+// If generating a new address for the prefix fails, the prefix is invalidated.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate address for %s", prefix))
+ }
+
+ if ndp.generateSLAACAddr(prefix, &state) {
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ // We were unable to generate a permanent address for the SLAAC prefix so
+ // invalidate the prefix as there is no reason to maintain state for a
+ // SLAAC prefix we do not have an address for.
+ ndp.invalidateSLAACPrefix(prefix, state)
+}
+
+// generateTempSLAACAddr generates a new temporary SLAAC address.
+//
+// If resetGenAttempts is true, the prefix's generation counter is reset.
+//
+// Returns true if a new address was generated.
+func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool {
+ // Are we configured to auto-generate new temporary global addresses for the
+ // prefix?
+ if !ndp.configs.AutoGenTempGlobalAddresses || prefix == header.IPv6LinkLocalPrefix.Subnet() {
+ return false
+ }
+
+ if resetGenAttempts {
+ prefixState.generationAttempts = 0
+ prefixState.maxGenerationAttempts = ndp.configs.AutoGenAddressConflictRetries + 1
+ }
+
+ // If we have already reached the maximum address generation attempts for the
+ // prefix, do not generate another address.
+ if prefixState.generationAttempts == prefixState.maxGenerationAttempts {
+ return false
+ }
+
+ stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address
+ now := time.Now()
+
+ // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
+ // address is the lower of the valid lifetime of the stable address or the
+ // maximum temporary address valid lifetime.
+ vl := ndp.configs.MaxTempAddrValidLifetime
+ if prefixState.validUntil != (time.Time{}) {
+ if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL {
+ vl = prefixVL
+ }
+ }
+
+ if vl <= 0 {
+ // Cannot create an address without a valid lifetime.
+ return false
+ }
+
+ // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
+ // address is the lower of the preferred lifetime of the stable address or the
+ // maximum temporary address preferred lifetime - the temporary address desync
+ // factor.
+ pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor
+ if prefixState.preferredUntil != (time.Time{}) {
+ if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL {
+ // Respect the preferred lifetime of the prefix, as per RFC 4941 section
+ // 3.3 step 4.
+ pl = prefixPL
+ }
+ }
+
+ // As per RFC 4941 section 3.3 step 5, a temporary address is created only if
+ // the calculated preferred lifetime is greater than the advance regeneration
+ // duration. In particular, we MUST NOT create a temporary address with a zero
+ // Preferred Lifetime.
+ if pl <= ndp.configs.RegenAdvanceDuration {
+ return false
+ }
+
+ // Attempt to generate a new address that is not already assigned to the IPv6
+ // endpoint.
+ var generatedAddr tcpip.AddressWithPrefix
+ for i := 0; ; i++ {
+ // If we were unable to generate an address after the maximum SLAAC address
+ // local regeneration attempts, do nothing further.
+ if i == maxSLAACAddrLocalRegenAttempts {
+ return false
+ }
+
+ generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr)
+ if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) {
+ break
+ }
+ }
+
+ // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary
+ // address with a zero preferred lifetime. The checks above ensure this
+ // so we know the address is not deprecated.
+ addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaacTemp, false /* deprecated */)
+ if addressEndpoint == nil {
+ return false
+ }
+
+ state := tempSLAACAddrState{
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr))
+ }
+
+ ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
+ }),
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to invalidate temporary address %s", generatedAddr))
+ }
+
+ ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
+ }),
+ regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to regenerate temporary address after %s", generatedAddr))
+ }
+
+ // If an address has already been regenerated for this address, don't
+ // regenerate another address.
+ if tempAddrState.regenerated {
+ return
+ }
+
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ tempAddrState.regenerated = ndp.generateTempSLAACAddr(prefix, &prefixState, true /* resetGenAttempts */)
+ prefixState.tempAddrs[generatedAddr.Address] = tempAddrState
+ ndp.slaacPrefixes[prefix] = prefixState
+ }),
+ createdAt: now,
+ addressEndpoint: addressEndpoint,
+ }
+
+ state.deprecationJob.Schedule(pl)
+ state.invalidationJob.Schedule(vl)
+ state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration)
+
+ prefixState.generationAttempts++
+ prefixState.tempAddrs[generatedAddr.Address] = state
+
+ return true
+}
+
+// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate temporary address for %s", prefix))
+ }
+
+ ndp.generateTempSLAACAddr(prefix, &state, resetGenAttempts)
+ ndp.slaacPrefixes[prefix] = state
+}
+
+// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) {
+ // If the preferred lifetime is zero, then the prefix should be deprecated.
+ deprecated := pl == 0
+ if deprecated {
+ ndp.deprecateSLAACAddress(prefixState.stableAddr.addressEndpoint)
+ } else {
+ prefixState.stableAddr.addressEndpoint.SetDeprecated(false)
+ }
+
+ // If prefix was preferred for some finite lifetime before, cancel the
+ // deprecation job so it can be reset.
+ prefixState.deprecationJob.Cancel()
+
+ now := time.Now()
+
+ // Schedule the deprecation job if prefix has a finite preferred lifetime.
+ if pl < header.NDPInfiniteLifetime {
+ if !deprecated {
+ prefixState.deprecationJob.Schedule(pl)
+ }
+ prefixState.preferredUntil = now.Add(pl)
+ } else {
+ prefixState.preferredUntil = time.Time{}
+ }
+
+ // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix:
+ //
+ // 1) If the received Valid Lifetime is greater than 2 hours or greater than
+ // RemainingLifetime, set the valid lifetime of the prefix to the
+ // advertised Valid Lifetime.
+ //
+ // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the
+ // advertised Valid Lifetime.
+ //
+ // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
+
+ if vl >= header.NDPInfiniteLifetime {
+ // Handle the infinite valid lifetime separately as we do not schedule a
+ // job in this case.
+ prefixState.invalidationJob.Cancel()
+ prefixState.validUntil = time.Time{}
+ } else {
+ var effectiveVl time.Duration
+ var rl time.Duration
+
+ // If the prefix was originally set to be valid forever, assume the
+ // remaining time to be the maximum possible value.
+ if prefixState.validUntil == (time.Time{}) {
+ rl = header.NDPInfiniteLifetime
+ } else {
+ rl = time.Until(prefixState.validUntil)
+ }
+
+ if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
+ effectiveVl = vl
+ } else if rl > MinPrefixInformationValidLifetimeForUpdate {
+ effectiveVl = MinPrefixInformationValidLifetimeForUpdate
+ }
+
+ if effectiveVl != 0 {
+ prefixState.invalidationJob.Cancel()
+ prefixState.invalidationJob.Schedule(effectiveVl)
+ prefixState.validUntil = now.Add(effectiveVl)
+ }
+ }
+
+ // If DAD is not yet complete on the stable address, there is no need to do
+ // work with temporary addresses.
+ if prefixState.stableAddr.addressEndpoint.GetKind() != stack.Permanent {
+ return
+ }
+
+ // Note, we do not need to update the entries in the temporary address map
+ // after updating the jobs because the jobs are held as pointers.
+ var regenForAddr tcpip.Address
+ allAddressesRegenerated := true
+ for tempAddr, tempAddrState := range prefixState.tempAddrs {
+ // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
+ // address is the lower of the valid lifetime of the stable address or the
+ // maximum temporary address valid lifetime. Note, the valid lifetime of a
+ // temporary address is relative to the address's creation time.
+ validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime)
+ if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 {
+ validUntil = prefixState.validUntil
+ }
+
+ // If the address is no longer valid, invalidate it immediately. Otherwise,
+ // reset the invalidation job.
+ newValidLifetime := validUntil.Sub(now)
+ if newValidLifetime <= 0 {
+ ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
+ continue
+ }
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.invalidationJob.Schedule(newValidLifetime)
+
+ // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
+ // address is the lower of the preferred lifetime of the stable address or
+ // the maximum temporary address preferred lifetime - the temporary address
+ // desync factor. Note, the preferred lifetime of a temporary address is
+ // relative to the address's creation time.
+ preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor)
+ if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 {
+ preferredUntil = prefixState.preferredUntil
+ }
+
+ // If the address is no longer preferred, deprecate it immediately.
+ // Otherwise, schedule the deprecation job again.
+ newPreferredLifetime := preferredUntil.Sub(now)
+ tempAddrState.deprecationJob.Cancel()
+ if newPreferredLifetime <= 0 {
+ ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
+ } else {
+ tempAddrState.addressEndpoint.SetDeprecated(false)
+ tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
+ }
+
+ tempAddrState.regenJob.Cancel()
+ if tempAddrState.regenerated {
+ } else {
+ allAddressesRegenerated = false
+
+ if newPreferredLifetime <= ndp.configs.RegenAdvanceDuration {
+ // The new preferred lifetime is less than the advance regeneration
+ // duration so regenerate an address for this temporary address
+ // immediately after we finish iterating over the temporary addresses.
+ regenForAddr = tempAddr
+ } else {
+ tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ }
+ }
+ }
+
+ // Generate a new temporary address if all of the existing temporary addresses
+ // have been regenerated, or we need to immediately regenerate an address
+ // due to an update in preferred lifetime.
+ //
+ // If each temporay address has already been regenerated, no new temporary
+ // address is generated. To ensure continuation of temporary SLAAC addresses,
+ // we manually try to regenerate an address here.
+ if len(regenForAddr) != 0 || allAddressesRegenerated {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ if state, ok := prefixState.tempAddrs[regenForAddr]; ndp.generateTempSLAACAddr(prefix, prefixState, true /* resetGenAttempts */) && ok {
+ state.regenerated = true
+ prefixState.tempAddrs[regenForAddr] = state
+ }
+ }
+}
+
+// deprecateSLAACAddress marks the address as deprecated and notifies the NDP
+// dispatcher that address has been deprecated.
+//
+// deprecateSLAACAddress does nothing if the address is already deprecated.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint) {
+ if addressEndpoint.Deprecated() {
+ return
+ }
+
+ addressEndpoint.SetDeprecated(true)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
+ }
+}
+
+// invalidateSLAACPrefix invalidates a SLAAC prefix.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) {
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+
+ if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil {
+ // Since we are already invalidating the prefix, do not invalidate the
+ // prefix when removing the address.
+ if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err))
+ }
+ }
+}
+
+// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's
+// resources.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
+ }
+
+ prefix := addr.Subnet()
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok || state.stableAddr.addressEndpoint == nil || addr.Address != state.stableAddr.addressEndpoint.AddressWithPrefix().Address {
+ return
+ }
+
+ if !invalidatePrefix {
+ // If the prefix is not being invalidated, disassociate the address from the
+ // prefix and do nothing further.
+ state.stableAddr.addressEndpoint.DecRef()
+ state.stableAddr.addressEndpoint = nil
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+}
+
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry.
+//
+// Panics if the SLAAC prefix is not known.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) {
+ // Invalidate all temporary addresses.
+ for tempAddr, tempAddrState := range state.tempAddrs {
+ ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState)
+ }
+
+ if state.stableAddr.addressEndpoint != nil {
+ state.stableAddr.addressEndpoint.DecRef()
+ state.stableAddr.addressEndpoint = nil
+ }
+ state.deprecationJob.Cancel()
+ state.invalidationJob.Cancel()
+ delete(ndp.slaacPrefixes, prefix)
+}
+
+// invalidateTempSLAACAddr invalidates a temporary SLAAC address.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
+ // Since we are already invalidating the address, do not invalidate the
+ // address when removing the address.
+ if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err))
+ }
+
+ ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState)
+}
+
+// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary
+// SLAAC address's resources from ndp.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
+ }
+
+ if !invalidateAddr {
+ return
+ }
+
+ prefix := addr.Subnet()
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry to clean up temp addr %s resources", addr))
+ }
+
+ tempAddrState, ok := state.tempAddrs[addr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to clean up temp addr %s resources", addr))
+ }
+
+ ndp.cleanupTempSLAACAddrResources(state.tempAddrs, addr.Address, tempAddrState)
+}
+
+// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's
+// jobs and entry.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
+ tempAddrState.addressEndpoint.DecRef()
+ tempAddrState.addressEndpoint = nil
+ tempAddrState.deprecationJob.Cancel()
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.regenJob.Cancel()
+ delete(tempAddrs, tempAddr)
+}
+
+// removeSLAACAddresses removes all SLAAC addresses.
+//
+// If keepLinkLocal is false, the SLAAC generated link-local address is removed.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) {
+ linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
+ var linkLocalPrefixes int
+ for prefix, state := range ndp.slaacPrefixes {
+ // RFC 4862 section 5 states that routers are also expected to generate a
+ // link-local address so we do not invalidate them if we are cleaning up
+ // host-only state.
+ if keepLinkLocal && prefix == linkLocalSubnet {
+ linkLocalPrefixes++
+ continue
+ }
+
+ ndp.invalidateSLAACPrefix(prefix, state)
+ }
+
+ if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
+ panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes))
+ }
+}
+
+// cleanupState cleans up ndp's state.
+//
+// If hostOnly is true, then only host-specific state is cleaned up.
+//
+// This function invalidates all discovered on-link prefixes, discovered
+// routers, and auto-generated addresses.
+//
+// If hostOnly is true, then the link-local auto-generated address aren't
+// invalidated as routers are also expected to generate a link-local address.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupState(hostOnly bool) {
+ ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */)
+
+ for prefix := range ndp.onLinkPrefixes {
+ ndp.invalidateOnLinkPrefix(prefix)
+ }
+
+ if got := len(ndp.onLinkPrefixes); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got))
+ }
+
+ for router := range ndp.defaultRouters {
+ ndp.invalidateDefaultRouter(router)
+ }
+
+ if got := len(ndp.defaultRouters); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got))
+ }
+
+ ndp.dhcpv6Configuration = 0
+}
+
+// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
+// 6.3.7. If routers are already being solicited, this function does nothing.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) startSolicitingRouters() {
+ if ndp.rtrSolicit.timer != nil {
+ // We are already soliciting routers.
+ return
+ }
+
+ remaining := ndp.configs.MaxRtrSolicitations
+ if remaining == 0 {
+ return
+ }
+
+ // Calculate the random delay before sending our first RS, as per RFC
+ // 4861 section 6.3.7.
+ var delay time.Duration
+ if ndp.configs.MaxRtrSolicitationDelay > 0 {
+ delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
+ }
+
+ var done bool
+ ndp.rtrSolicit.done = &done
+ ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
+ ndp.ep.mu.Lock()
+ if done {
+ // If we reach this point, it means that the RS timer fired after another
+ // goroutine already obtained the IPv6 endpoint lock and stopped
+ // solicitations. Simply return here and do nothing further.
+ ndp.ep.mu.Unlock()
+ return
+ }
+
+ // As per RFC 4861 section 4.1, the source of the RS is an address assigned
+ // to the sending interface, or the unspecified address if no address is
+ // assigned to the sending interface.
+ addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
+ if addressEndpoint == nil {
+ // Incase this ends up creating a new temporary address, we need to hold
+ // onto the endpoint until a route is obtained. If we decrement the
+ // reference count before obtaing a route, the address's resources would
+ // be released and attempting to obtain a route after would fail. Once a
+ // route is obtainted, it is safe to decrement the reference count since
+ // obtaining a route increments the address's reference count.
+ addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
+ }
+ ndp.ep.mu.Unlock()
+
+ localAddr := addressEndpoint.AddressWithPrefix().Address
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
+ addressEndpoint.DecRef()
+ if err != nil {
+ return
+ }
+ defer r.Release()
+
+ // Route should resolve immediately since
+ // header.IPv6AllRoutersMulticastAddress is a multicast address so a
+ // remote link address can be calculated without a resolution process.
+ if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the IPv6 endpoint is
+ // not locked, the IPv6 endpoint could have been disabled or removed by
+ // another goroutine.
+ if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
+ return
+ }
+
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
+ } else if c != nil {
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
+ }
+
+ // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
+ // link-layer address option if the source address of the NDP RS is
+ // specified. This option MUST NOT be included if the source address is
+ // unspecified.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ var optsSerializer header.NDPOptionsSerializer
+ if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) {
+ optsSerializer = header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress),
+ }
+ }
+ payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
+ icmpData := header.ICMPv6(buffer.NewView(payloadSize))
+ icmpData.SetType(header.ICMPv6RouterSolicit)
+ rs := header.NDPRouterSolicit(icmpData.NDPPayload())
+ rs.Options().Serialize(optsSerializer)
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(icmpData).ToVectorisedView(),
+ })
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil,
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, pkt,
+ ); err != nil {
+ sent.Dropped.Increment()
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
+ // Don't send any more messages if we had an error.
+ remaining = 0
+ } else {
+ sent.RouterSolicit.Increment()
+ remaining--
+ }
+
+ ndp.ep.mu.Lock()
+ if done || remaining == 0 {
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
+ } else if ndp.rtrSolicit.timer != nil {
+ // Note, we need to explicitly check to make sure that
+ // the timer field is not nil because if it was nil but
+ // we still reached this point, then we know the IPv6 endpoint
+ // was requested to stop soliciting routers so we don't
+ // need to send the next Router Solicitation message.
+ ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
+ }
+ ndp.ep.mu.Unlock()
+ })
+
+}
+
+// stopSolicitingRouters stops soliciting routers. If routers are not currently
+// being solicited, this function does nothing.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) stopSolicitingRouters() {
+ if ndp.rtrSolicit.timer == nil {
+ // Nothing to do.
+ return
+ }
+
+ *ndp.rtrSolicit.done = true
+ ndp.rtrSolicit.timer.Stop()
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
+}
+
+// initializeTempAddrState initializes state related to temporary SLAAC
+// addresses.
+func (ndp *ndpState) initializeTempAddrState() {
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID())
+
+ if MaxDesyncFactor != 0 {
+ ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 12b70f7e9..ac20f217e 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -15,9 +15,12 @@
package ipv6
import (
+ "context"
"strings"
"testing"
+ "time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -30,12 +33,13 @@ import (
// setupStackAndEndpoint creates a stack with a single NIC with a link-local
// address llladdr and an IPv6 endpoint to a remote with link-local address
// rlladdr
-func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
+func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeighborCache bool) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ UseNeighborCache: useNeighborCache,
})
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
@@ -63,14 +67,94 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
}
+ t.Cleanup(ep.Close)
return s, ep
}
+var _ NDPDispatcher = (*testNDPDispatcher)(nil)
+
+// testNDPDispatcher is an NDPDispatcher only allows default router discovery.
+type testNDPDispatcher struct {
+ addr tcpip.Address
+}
+
+func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+}
+
+func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
+ t.addr = addr
+ return true
+}
+
+func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) {
+ t.addr = addr
+}
+
+func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
+ return false
+}
+
+func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {
+}
+
+func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
+ return false
+}
+
+func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {
+}
+
+func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {
+}
+
+func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {
+}
+
+func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {
+}
+
+func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) {
+}
+
+func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
+ var ndpDisp testNDPDispatcher
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
+ NDPDisp: &ndpDisp,
+ })},
+ })
+
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err)
+ }
+
+ ipv6EP := ep.(*endpoint)
+ ipv6EP.mu.Lock()
+ ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour)
+ ipv6EP.mu.Unlock()
+
+ if ndpDisp.addr != lladdr1 {
+ t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
+ }
+
+ ndpDisp.addr = ""
+ ndpEP := ep.(stack.NDPEndpoint)
+ ndpEP.InvalidateDefaultRouter(lladdr1)
+ if ndpDisp.addr != lladdr1 {
+ t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
+ }
+}
+
// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
@@ -100,7 +184,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
e := channel.New(0, 1280, linkAddr0)
if err := s.CreateNIC(nicID, e); err != nil {
@@ -136,9 +220,9 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
if linkAddr != test.expectedLinkAddr {
@@ -174,6 +258,123 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
}
}
+// TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests
+// that receiving a valid NDP NS message with the Source Link Layer Address
+// option results in a new entry in the link address cache for the sender of
+// the message.
+func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ neighbors, err := s.Neighbors(nicID)
+ if err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ }
+
+ neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
+ for _, n := range neighbors {
+ if existing, ok := neighborByAddr[n.Addr]; ok {
+ if diff := cmp.Diff(existing, n); diff != "" {
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
+ }
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing)
+ }
+ neighborByAddr[n.Addr] = n
+ }
+
+ if neigh, ok := neighborByAddr[lladdr1]; len(test.expectedLinkAddr) != 0 {
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+
+ if !ok {
+ t.Fatalf("expected a neighbor entry for %q", lladdr1)
+ }
+ if neigh.LinkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", neigh.LinkAddr, test.expectedLinkAddr)
+ }
+ if neigh.State != stack.Stale {
+ t.Errorf("got NUD state = %s, want = %s", neigh.State, stack.Stale)
+ }
+ } else {
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+
+ if ok {
+ t.Fatalf("unexpectedly got neighbor entry: %s", neigh)
+ }
+ }
+ })
+ }
+}
+
func TestNeighorSolicitationResponse(t *testing.T) {
const nicID = 1
nicAddr := lladdr0
@@ -183,26 +384,41 @@ func TestNeighorSolicitationResponse(t *testing.T) {
remoteLinkAddr0 := linkAddr1
remoteLinkAddr1 := linkAddr2
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
+
tests := []struct {
- name string
- nsOpts header.NDPOptionsSerializer
- nsSrcLinkAddr tcpip.LinkAddress
- nsSrc tcpip.Address
- nsDst tcpip.Address
- nsInvalid bool
- naDstLinkAddr tcpip.LinkAddress
- naSolicited bool
- naSrc tcpip.Address
- naDst tcpip.Address
+ name string
+ nsOpts header.NDPOptionsSerializer
+ nsSrcLinkAddr tcpip.LinkAddress
+ nsSrc tcpip.Address
+ nsDst tcpip.Address
+ nsInvalid bool
+ naDstLinkAddr tcpip.LinkAddress
+ naSolicited bool
+ naSrc tcpip.Address
+ naDst tcpip.Address
+ performsLinkResolution bool
}{
{
- name: "Unspecified source to multicast destination",
+ name: "Unspecified source to solicited-node multicast destination",
nsOpts: nil,
nsSrcLinkAddr: remoteLinkAddr0,
nsSrc: header.IPv6Any,
nsDst: nicAddrSNMC,
nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
+ naDstLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
naSolicited: false,
naSrc: nicAddr,
naDst: header.IPv6AllNodesMulticastAddress,
@@ -223,11 +439,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
nsSrcLinkAddr: remoteLinkAddr0,
nsSrc: header.IPv6Any,
nsDst: nicAddr,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
- naSolicited: false,
- naSrc: nicAddr,
- naDst: header.IPv6AllNodesMulticastAddress,
+ nsInvalid: true,
},
{
name: "Unspecified source with source ll option to unicast destination",
@@ -239,7 +451,6 @@ func TestNeighorSolicitationResponse(t *testing.T) {
nsDst: nicAddr,
nsInvalid: true,
},
-
{
name: "Specified source with 1 source ll to multicast destination",
nsOpts: header.NDPOptionsSerializer{
@@ -299,6 +510,10 @@ func TestNeighorSolicitationResponse(t *testing.T) {
naSolicited: true,
naSrc: nicAddr,
naDst: remoteAddr,
+ // Since we send a unicast solicitations to a node without an entry for
+ // the remote, the node needs to perform neighbor discovery to get the
+ // remote's link address to send the advertisement response.
+ performsLinkResolution: true,
},
{
name: "Specified source with 1 source ll to unicast destination",
@@ -341,86 +556,159 @@ func TestNeighorSolicitationResponse(t *testing.T) {
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- e := channel.New(1, 1280, nicLinkAddr)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
- }
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: stackTyp.useNeighborCache,
+ })
+ e := channel.New(1, 1280, nicLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ }
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
- ns.SetTargetAddress(nicAddr)
- opts := ns.Options()
- opts.Serialize(test.nsOpts)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: test.nsSrc,
- DstAddr: test.nsDst,
- })
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(nicAddr)
+ opts := ns.Options()
+ opts.Serialize(test.nsOpts)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
+ e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
- e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
+ if test.nsInvalid {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
- if test.nsInvalid {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
+ if p, got := e.Read(); got {
+ t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
+ }
- if p, got := e.Read(); got {
- t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
- }
+ // If we expected the NS to be invalid, we have nothing else to check.
+ return
+ }
- // If we expected the NS to be invalid, we have nothing else to check.
- return
- }
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
+ if test.performsLinkResolution {
+ p, got := e.ReadContext(context.Background())
+ if !got {
+ t.Fatal("expected an NDP NS response")
+ }
+
+ if p.Route.LocalAddress != nicAddr {
+ t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, nicAddr)
+ }
+ if p.Route.LocalLinkAddress != nicLinkAddr {
+ t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr)
+ }
+ respNSDst := header.SolicitedNodeAddr(test.nsSrc)
+ if p.Route.RemoteAddress != respNSDst {
+ t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst)
+ }
+ if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(nicAddr),
+ checker.DstAddr(respNSDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(test.nsSrc),
+ checker.NDPNSOptions([]header.NDPOption{
+ header.NDPSourceLinkLayerAddressOption(nicLinkAddr),
+ }),
+ ))
+
+ ser := header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ }
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(test.nsSrc)
+ na.Options().Serialize(ser)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.nsSrc,
+ DstAddr: nicAddr,
+ })
+ e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
- p, got := e.Read()
- if !got {
- t.Fatal("expected an NDP NA response")
- }
+ p, got := e.ReadContext(context.Background())
+ if !got {
+ t.Fatal("expected an NDP NA response")
+ }
- if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
- }
+ if p.Route.LocalAddress != test.naSrc {
+ t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc)
+ }
+ if p.Route.LocalLinkAddress != nicLinkAddr {
+ t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr)
+ }
+ if p.Route.RemoteAddress != test.naDst {
+ t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst)
+ }
+ if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ }
- checker.IPv6(t, p.Pkt.Header.View(),
- checker.SrcAddr(test.naSrc),
- checker.DstAddr(test.naDst),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNA(
- checker.NDPNASolicitedFlag(test.naSolicited),
- checker.NDPNATargetAddress(nicAddr),
- checker.NDPNAOptions([]header.NDPOption{
- header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
- }),
- ))
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.naSrc),
+ checker.DstAddr(test.naDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNA(
+ checker.NDPNASolicitedFlag(test.naSolicited),
+ checker.NDPNATargetAddress(nicAddr),
+ checker.NDPNAOptions([]header.NDPOption{
+ header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
+ }),
+ ))
+ })
+ }
})
}
}
@@ -461,7 +749,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
e := channel.New(0, 1280, linkAddr0)
if err := s.CreateNIC(nicID, e); err != nil {
@@ -497,9 +785,9 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
if linkAddr != test.expectedLinkAddr {
@@ -535,200 +823,385 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
}
}
-func TestNDPValidation(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
- t.Helper()
-
- // Create a stack with the assigned link-local address lladdr0
- // and an endpoint to lladdr1.
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
-
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
- }
-
- return s, ep, r
- }
-
- handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
- nextHdr := uint8(header.ICMPv6ProtocolNumber)
- if atomicFragment {
- bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength)
- bytes[0] = nextHdr
- nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
- }
-
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: nextHdr,
- HopLimit: hopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- })
- ep.HandlePacket(r, stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
- }
-
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
+// TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests
+// that receiving a valid NDP NA message with the Target Link Layer Address
+// option does not result in a new entry in the neighbor cache for the target
+// of the message.
+func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) {
+ const nicID = 1
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- extraData []byte
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ tests := []struct {
+ name string
+ optsBuf []byte
+ isValid bool
}{
{
- name: "RouterSolicit",
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- },
- },
- {
- name: "RouterAdvert",
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- },
+ name: "Valid",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
+ isValid: true,
},
{
- name: "NeighborSolicit",
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- },
+ name: "Too Small",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
},
{
- name: "NeighborAdvert",
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- },
+ name: "Invalid Length",
+ optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
},
{
- name: "RedirectMsg",
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
+ name: "Multiple",
+ optsBuf: []byte{
+ 2, 1, 2, 3, 4, 5, 6, 7,
+ 2, 1, 2, 3, 4, 5, 6, 8,
},
},
}
- subTests := []struct {
- name string
- atomicFragment bool
- hopLimit uint8
- code uint8
- valid bool
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr1)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ neighbors, err := s.Neighbors(nicID)
+ if err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ }
+
+ neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
+ for _, n := range neighbors {
+ if existing, ok := neighborByAddr[n.Addr]; ok {
+ if diff := cmp.Diff(existing, n); diff != "" {
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
+ }
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing)
+ }
+ neighborByAddr[n.Addr] = n
+ }
+
+ if neigh, ok := neighborByAddr[lladdr1]; ok {
+ t.Fatalf("unexpectedly got neighbor entry: %s", neigh)
+ }
+
+ if test.isValid {
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+func TestNDPValidation(t *testing.T) {
+ stacks := []struct {
+ name string
+ useNeighborCache bool
}{
{
- name: "Valid",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: true,
- },
- {
- name: "Fragmented",
- atomicFragment: true,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: false,
- },
- {
- name: "Invalid hop limit",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit - 1,
- code: 0,
- valid: false,
+ name: "linkAddrCache",
+ useNeighborCache: false,
},
{
- name: "Invalid ICMPv6 code",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 1,
- valid: false,
+ name: "neighborCache",
+ useNeighborCache: true,
},
}
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- for _, test := range subTests {
- t.Run(test.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ t.Helper()
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- pkt.SetCode(test.code)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
+ // Create a stack with the assigned link-local address lladdr0
+ // and an endpoint to lladdr1.
+ s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache)
- // Rx count of the NDP message should initially be 0.
- if got := typStat.Value(); got != 0 {
- t.Errorf("got %s = %d, want = 0", typ.name, got)
- }
+ r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
+ return s, ep, r
+ }
- if t.Failed() {
- t.FailNow()
- }
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ nextHdr := uint8(header.ICMPv6ProtocolNumber)
+ var extensions buffer.View
+ if atomicFragment {
+ extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
+ extensions[0] = nextHdr
+ nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ }
- handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions),
+ Data: payload.ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions)))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(payload) + len(extensions)),
+ NextHeader: nextHdr,
+ HopLimit: hopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
+ t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
+ }
+ ep.HandlePacket(r, pkt)
+ }
- // Rx count of the NDP packet should have increased.
- if got := typStat.Value(); got != 1 {
- t.Errorf("got %s = %d, want = 1", typ.name, got)
- }
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
- want := uint64(0)
- if !test.valid {
- // Invalid count should have increased.
- want = 1
- }
- if got := invalid.Value(); got != want {
- t.Errorf("got invalid = %d, want = %d", got, want)
+ var sllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ routerOnly bool
+ }{
+ {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ routerOnly: true,
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ extraData: sllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ atomicFragment bool
+ hopLimit uint8
+ code header.ICMPv6Code
+ valid bool
+ }{
+ {
+ name: "Valid",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: true,
+ },
+ {
+ name: "Fragmented",
+ atomicFragment: true,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid hop limit",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit - 1,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid ICMPv6 code",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 1,
+ valid: false,
+ },
+ }
+
+ for _, typ := range types {
+ for _, isRouter := range []bool{false, true} {
+ name := typ.name
+ if isRouter {
+ name += " (Router)"
}
- })
+
+ t.Run(name, func(t *testing.T) {
+ for _, test := range subTests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep, r := setup(t)
+ defer r.Release()
+
+ if isRouter {
+ // Enabling forwarding makes the stack act as a router.
+ s.SetForwarding(ProtocolNumber, true)
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ routerOnly := stats.RouterOnlyPacketsDroppedByHost
+ typStat := typ.statCounter(stats)
+
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetCode(test.code)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+
+ // Rx count of the NDP message should initially be 0.
+ if got := typStat.Value(); got != 0 {
+ t.Errorf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+
+ // RouterOnlyPacketsReceivedByHost count should initially be 0.
+ if got := routerOnly.Value(); got != 0 {
+ t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+
+ // Rx count of the NDP packet should have increased.
+ if got := typStat.Value(); got != 1 {
+ t.Errorf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ want := uint64(0)
+ if !test.valid {
+ // Invalid count should have increased.
+ want = 1
+ }
+ if got := invalid.Value(); got != want {
+ t.Errorf("got invalid = %d, want = %d", got, want)
+ }
+
+ want = 0
+ if test.valid && !isRouter && typ.routerOnly {
+ // RouterOnlyPacketsReceivedByHost count should have increased.
+ want = 1
+ }
+ if got := routerOnly.Value(); got != want {
+ t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want)
+ }
+
+ })
+ }
+ })
+ }
}
})
}
+
}
// TestRouterAdvertValidation tests that when the NIC is configured to handle
// NDP Router Advertisement packets, it validates the Router Advertisement
// properly before handling them.
func TestRouterAdvertValidation(t *testing.T) {
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
+
tests := []struct {
name string
src tcpip.Address
hopLimit uint8
- code uint8
+ code header.ICMPv6Code
ndpPayload []byte
expectedSuccess bool
}{
@@ -845,61 +1318,67 @@ func TestRouterAdvertValidation(t *testing.T) {
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: stackTyp.useNeighborCache,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
- icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(header.ICMPv6RouterAdvert)
- pkt.SetCode(test.code)
- copy(pkt.NDPPayload(), test.ndpPayload)
- payloadLength := hdr.UsedLength()
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- })
+ icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(test.code)
+ copy(pkt.NDPPayload(), test.ndpPayload)
+ payloadLength := hdr.UsedLength()
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- rxRA := stats.RouterAdvert
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ rxRA := stats.RouterAdvert
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := rxRA.Value(); got != 0 {
- t.Fatalf("got rxRA = %d, want = 0", got)
- }
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := rxRA.Value(); got != 0 {
+ t.Fatalf("got rxRA = %d, want = 0", got)
+ }
- e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
- if got := rxRA.Value(); got != 1 {
- t.Fatalf("got rxRA = %d, want = 1", got)
- }
+ if got := rxRA.Value(); got != 1 {
+ t.Fatalf("got rxRA = %d, want = 1", got)
+ }
- if test.expectedSuccess {
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- } else {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
+ if test.expectedSuccess {
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
}
})
}
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
new file mode 100644
index 000000000..d0ffc299a
--- /dev/null
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "testutil",
+ srcs = [
+ "testutil.go",
+ ],
+ visibility = [
+ "//pkg/tcpip/network/fragmentation:__pkg__",
+ "//pkg/tcpip/network/ipv4:__pkg__",
+ "//pkg/tcpip/network/ipv6:__pkg__",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
new file mode 100644
index 000000000..7cc52985e
--- /dev/null
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -0,0 +1,144 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil defines types and functions used to test Network Layer
+// functionality such as IP fragmentation.
+package testutil
+
+import (
+ "fmt"
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// MockLinkEndpoint is an endpoint used for testing, it stores packets written
+// to it and can mock errors.
+type MockLinkEndpoint struct {
+ // WrittenPackets is where packets written to the endpoint are stored.
+ WrittenPackets []*stack.PacketBuffer
+
+ mtu uint32
+ err *tcpip.Error
+ allowPackets int
+}
+
+// NewMockLinkEndpoint creates a new MockLinkEndpoint.
+//
+// err is the error that will be returned once allowPackets packets are written
+// to the endpoint.
+func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint {
+ return &MockLinkEndpoint{
+ mtu: mtu,
+ err: err,
+ allowPackets: allowPackets,
+ }
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu }
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 }
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
+
+// LinkAddress implements LinkEndpoint.LinkAddress.
+func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (ep *MockLinkEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if ep.allowPackets == 0 {
+ return ep.err
+ }
+ ep.allowPackets--
+ ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ var n int
+
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := ep.WritePacket(r, gso, protocol, pkt); err != nil {
+ return n, err
+ }
+ n++
+ }
+
+ return n, nil
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ if ep.allowPackets == 0 {
+ return ep.err
+ }
+ ep.allowPackets--
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+ ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+
+ return nil
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (*MockLinkEndpoint) IsAttached() bool { return false }
+
+// Wait implements LinkEndpoint.Wait.
+func (*MockLinkEndpoint) Wait() {}
+
+// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
+func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
+
+// AddHeader implements LinkEndpoint.AddHeader.
+func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
+// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
+// how many random bytes will be copied in the Transport Header.
+// extraHeaderReserveLength indicates how much extra space will be reserved for
+// the other headers. The payload is made from Views of the sizes listed in
+// viewSizes.
+func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer {
+ var views buffer.VectorisedView
+
+ for _, s := range viewSizes {
+ newView := buffer.NewView(s)
+ if _, err := rand.Read(newView); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ views.AppendView(newView)
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength,
+ Data: views,
+ })
+ pkt.NetworkProtocolNumber = proto
+ if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ return pkt
+}