summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/arp/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go47
-rw-r--r--pkg/tcpip/network/arp/arp_test.go331
-rw-r--r--pkg/tcpip/network/ip_test.go14
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go2
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go278
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go447
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go4
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go875
9 files changed, 1479 insertions, 520 deletions
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index eddf7b725..82c073e32 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -28,5 +28,6 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 920872c3f..cbbe5b77f 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -46,6 +46,7 @@ type endpoint struct {
nicID tcpip.NICID
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
+ nud stack.NUDHandler
}
// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
@@ -78,7 +79,7 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return e.protocol.Number()
+ return ProtocolNumber
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
@@ -99,9 +100,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
switch h.Op() {
case header.ARPRequest:
localAddr := tcpip.Address(h.ProtocolAddressTarget())
- if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
- return // we have no useful answer, ignore the request
+
+ if e.nud == nil {
+ if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+ } else {
+ if r.Stack().CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+
+ remoteAddr := tcpip.Address(h.ProtocolAddressSender())
+ remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
+
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize,
})
@@ -113,11 +130,28 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
_ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
- fallthrough // also fill the cache from requests
+
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+
+ if e.nud == nil {
+ e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+ return
+ }
+
+ // The solicited, override, and isRouter flags are not available for ARP;
+ // they are only available for IPv6 Neighbor Advertisements.
+ e.nud.HandleConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{
+ // Solicited and unsolicited (also referred to as gratuitous) ARP Replies
+ // are handled equivalently to a solicited Neighbor Advertisement.
+ Solicited: true,
+ // If a different link address is received than the one cached, the entry
+ // should always go to Stale.
+ Override: false,
+ // ARP does not distinguish between router and non-router hosts.
+ IsRouter: false,
+ })
}
}
@@ -134,12 +168,13 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
return &endpoint{
protocol: p,
nicID: nicID,
linkEP: sender,
linkAddrCache: linkAddrCache,
+ nud: nud,
}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index c2c3e6891..9c9a859e3 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -16,10 +16,12 @@ package arp_test
import (
"context"
+ "fmt"
"strconv"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -32,57 +34,192 @@ import (
)
const (
- stackLinkAddr1 = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
- stackLinkAddr2 = tcpip.LinkAddress("\x0b\x0b\x0c\x0c\x0d\x0d")
- stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
- stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
- stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+ nicID = 1
+
+ stackAddr = tcpip.Address("\x0a\x00\x00\x01")
+ stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
+
+ remoteAddr = tcpip.Address("\x0a\x00\x00\x02")
+ remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
+
+ unknownAddr = tcpip.Address("\x0a\x00\x00\x03")
defaultChannelSize = 1
defaultMTU = 65536
+
+ // eventChanSize defines the size of event channels used by the neighbor
+ // cache's event dispatcher. The size chosen here needs to be sufficient to
+ // queue all the events received during tests before consumption.
+ // If eventChanSize is too small, the tests may deadlock.
+ eventChanSize = 32
+)
+
+type eventType uint8
+
+const (
+ entryAdded eventType = iota
+ entryChanged
+ entryRemoved
)
+func (t eventType) String() string {
+ switch t {
+ case entryAdded:
+ return "add"
+ case entryChanged:
+ return "change"
+ case entryRemoved:
+ return "remove"
+ default:
+ return fmt.Sprintf("unknown (%d)", t)
+ }
+}
+
+type eventInfo struct {
+ eventType eventType
+ nicID tcpip.NICID
+ addr tcpip.Address
+ linkAddr tcpip.LinkAddress
+ state stack.NeighborState
+}
+
+func (e eventInfo) String() string {
+ return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state)
+}
+
+// arpDispatcher implements NUDDispatcher to validate the dispatching of
+// events upon certain NUD state machine events.
+type arpDispatcher struct {
+ // C is where events are queued
+ C chan eventInfo
+}
+
+var _ stack.NUDDispatcher = (*arpDispatcher)(nil)
+
+func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+ e := eventInfo{
+ eventType: entryAdded,
+ nicID: nicID,
+ addr: addr,
+ linkAddr: linkAddr,
+ state: state,
+ }
+ d.C <- e
+}
+
+func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+ e := eventInfo{
+ eventType: entryChanged,
+ nicID: nicID,
+ addr: addr,
+ linkAddr: linkAddr,
+ state: state,
+ }
+ d.C <- e
+}
+
+func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+ e := eventInfo{
+ eventType: entryRemoved,
+ nicID: nicID,
+ addr: addr,
+ linkAddr: linkAddr,
+ state: state,
+ }
+ d.C <- e
+}
+
+func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
+ select {
+ case got := <-d.C:
+ if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" {
+ return fmt.Errorf("got invalid event (-got +want):\n%s", diff)
+ }
+ case <-ctx.Done():
+ return fmt.Errorf("%s for %s", ctx.Err(), want)
+ }
+ return nil
+}
+
+func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error {
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ return d.waitForEvent(ctx, want)
+}
+
+func (d *arpDispatcher) nextEvent() (eventInfo, bool) {
+ select {
+ case event := <-d.C:
+ return event, true
+ default:
+ return eventInfo{}, false
+ }
+}
+
type testContext struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
+ s *stack.Stack
+ linkEP *channel.Endpoint
+ nudDisp *arpDispatcher
}
-func newTestContext(t *testing.T) *testContext {
+func newTestContext(t *testing.T, useNeighborCache bool) *testContext {
+ c := stack.DefaultNUDConfigurations()
+ // Transition from Reachable to Stale almost immediately to test if receiving
+ // probes refreshes positive reachability.
+ c.BaseReachableTime = time.Microsecond
+
+ d := arpDispatcher{
+ // Create an event channel large enough so the neighbor cache doesn't block
+ // while dispatching events. Blocking could interfere with the timing of
+ // NUD transitions.
+ C: make(chan eventInfo, eventChanSize),
+ }
+
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ NUDConfigs: c,
+ NUDDisp: &d,
+ UseNeighborCache: useNeighborCache,
})
- ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
+ ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, wep); err != nil {
+ if err := s.CreateNIC(nicID, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
t.Fatalf("AddAddress for ipv4 failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ if !useNeighborCache {
+ // The remote address needs to be assigned to the NIC so we can receive and
+ // verify outgoing ARP packets. The neighbor cache isn't concerned with
+ // this; the tests that use linkAddrCache expect the ARP responses to be
+ // received by the same NIC.
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, remoteAddr); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ }
}
- if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
t.Fatalf("AddAddress for arp failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
- NIC: 1,
+ NIC: nicID,
}})
return &testContext{
- t: t,
- s: s,
- linkEP: ep,
+ s: s,
+ linkEP: ep,
+ nudDisp: &d,
}
}
@@ -91,7 +228,7 @@ func (c *testContext) cleanup() {
}
func TestDirectRequest(t *testing.T) {
- c := newTestContext(t)
+ c := newTestContext(t, false /* useNeighborCache */)
defer c.cleanup()
const senderMAC = "\x01\x02\x03\x04\x05\x06"
@@ -111,7 +248,7 @@ func TestDirectRequest(t *testing.T) {
}))
}
- for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
+ for i, address := range []tcpip.Address{stackAddr, remoteAddr} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
inject(address)
pi, _ := c.linkEP.ReadContext(context.Background())
@@ -122,7 +259,7 @@ func TestDirectRequest(t *testing.T) {
if !rep.IsValid() {
t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
}
- if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want {
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
}
if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
@@ -137,7 +274,7 @@ func TestDirectRequest(t *testing.T) {
})
}
- inject(stackAddrBad)
+ inject(unknownAddr)
// Sleep tests are gross, but this will only potentially flake
// if there's a bug. If there is no bug this will reliably
// succeed.
@@ -148,6 +285,144 @@ func TestDirectRequest(t *testing.T) {
}
}
+func TestDirectRequestWithNeighborCache(t *testing.T) {
+ c := newTestContext(t, true /* useNeighborCache */)
+ defer c.cleanup()
+
+ tests := []struct {
+ name string
+ senderAddr tcpip.Address
+ senderLinkAddr tcpip.LinkAddress
+ targetAddr tcpip.Address
+ isValid bool
+ }{
+ {
+ name: "Loopback",
+ senderAddr: stackAddr,
+ senderLinkAddr: stackLinkAddr,
+ targetAddr: stackAddr,
+ isValid: true,
+ },
+ {
+ name: "Remote",
+ senderAddr: remoteAddr,
+ senderLinkAddr: remoteLinkAddr,
+ targetAddr: stackAddr,
+ isValid: true,
+ },
+ {
+ name: "RemoteInvalidTarget",
+ senderAddr: remoteAddr,
+ senderLinkAddr: remoteLinkAddr,
+ targetAddr: unknownAddr,
+ isValid: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Inject an incoming ARP request.
+ v := make(buffer.View, header.ARPSize)
+ h := header.ARP(v)
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), test.senderLinkAddr)
+ copy(h.ProtocolAddressSender(), test.senderAddr)
+ copy(h.ProtocolAddressTarget(), test.targetAddr)
+ c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{
+ Data: v.ToVectorisedView(),
+ })
+
+ if !test.isValid {
+ // No packets should be sent after receiving an invalid ARP request.
+ // There is no need to perform a blocking read here, since packets are
+ // sent in the same function that handles ARP requests.
+ if pkt, ok := c.linkEP.Read(); ok {
+ t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto)
+ }
+ return
+ }
+
+ // Verify an ARP response was sent.
+ pi, ok := c.linkEP.Read()
+ if !ok {
+ t.Fatal("expected ARP response to be sent, got none")
+ }
+
+ if pi.Proto != arp.ProtocolNumber {
+ t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
+ }
+ rep := header.ARP(pi.Pkt.NetworkHeader().View())
+ if !rep.IsValid() {
+ t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ t.Errorf("got HardwareAddressSender() = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
+ t.Errorf("got ProtocolAddressSender() = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want {
+ t.Errorf("got HardwareAddressTarget() = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
+ t.Errorf("got ProtocolAddressTarget() = %s, want = %s", got, want)
+ }
+
+ // Verify the sender was saved in the neighbor cache.
+ wantEvent := eventInfo{
+ eventType: entryAdded,
+ nicID: nicID,
+ addr: test.senderAddr,
+ linkAddr: tcpip.LinkAddress(test.senderLinkAddr),
+ state: stack.Stale,
+ }
+ if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil {
+ t.Fatal(err)
+ }
+
+ neighbors, err := c.s.Neighbors(nicID)
+ if err != nil {
+ t.Fatalf("c.s.Neighbors(%d): %s", nicID, err)
+ }
+
+ neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
+ for _, n := range neighbors {
+ if existing, ok := neighborByAddr[n.Addr]; ok {
+ if diff := cmp.Diff(existing, n); diff != "" {
+ t.Fatalf("duplicate neighbor entry found (-existing +got):\n%s", diff)
+ }
+ t.Fatalf("exact neighbor entry duplicate found for addr=%s", n.Addr)
+ }
+ neighborByAddr[n.Addr] = n
+ }
+
+ neigh, ok := neighborByAddr[test.senderAddr]
+ if !ok {
+ t.Fatalf("expected neighbor entry with Addr = %s", test.senderAddr)
+ }
+ if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want {
+ t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want)
+ }
+ if got, want := neigh.LocalAddr, stackAddr; got != want {
+ t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want)
+ }
+ if got, want := neigh.State, stack.Stale; got != want {
+ t.Errorf("got neighbor State = %s, want = %s", got, want)
+ }
+
+ // No more events should be dispatched
+ for {
+ event, ok := c.nudDisp.nextEvent()
+ if !ok {
+ break
+ }
+ t.Errorf("unexpected %s", event)
+ }
+ })
+ }
+}
+
func TestLinkAddressRequest(t *testing.T) {
tests := []struct {
name string
@@ -156,8 +431,8 @@ func TestLinkAddressRequest(t *testing.T) {
}{
{
name: "Unicast",
- remoteLinkAddr: stackLinkAddr2,
- expectLinkAddr: stackLinkAddr2,
+ remoteLinkAddr: remoteLinkAddr,
+ expectLinkAddr: remoteLinkAddr,
},
{
name: "Multicast",
@@ -173,9 +448,9 @@ func TestLinkAddressRequest(t *testing.T) {
t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
}
- linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
- if err := linkRes.LinkAddressRequest(stackAddr1, stackAddr2, test.remoteLinkAddr, linkEP); err != nil {
- t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr1, stackAddr2, test.remoteLinkAddr, err)
+ linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err)
}
pkt, ok := linkEP.Read()
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 9007346fe..e45dd17f8 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -250,7 +250,7 @@ func buildDummyStack(t *testing.T) *stack.Stack {
func TestIPv4Send(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, nil, &o, buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, buildDummyStack(t))
defer ep.Close()
// Allocate and initialize the payload view.
@@ -287,7 +287,7 @@ func TestIPv4Send(t *testing.T) {
func TestIPv4Receive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
totalLen := header.IPv4MinimumSize + 30
@@ -357,7 +357,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv4.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
@@ -418,7 +418,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
func TestIPv4FragmentationReceive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
totalLen := header.IPv4MinimumSize + 24
@@ -495,7 +495,7 @@ func TestIPv4FragmentationReceive(t *testing.T) {
func TestIPv6Send(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t))
defer ep.Close()
// Allocate and initialize the payload view.
@@ -532,7 +532,7 @@ func TestIPv6Send(t *testing.T) {
func TestIPv6Receive(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
totalLen := header.IPv6MinimumSize + 30
@@ -611,7 +611,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 63ffb3660..55ca94268 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -59,7 +59,7 @@ type endpoint struct {
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
return &endpoint{
nicID: nicID,
linkEP: linkEP,
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 66d3a953a..2b83c421e 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,8 +15,6 @@
package ipv6
import (
- "fmt"
-
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -71,6 +69,59 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
+// getLinkAddrOption searches NDP options for a given link address option using
+// the provided getAddr function as a filter. Returns the link address if
+// found; otherwise, returns the zero link address value. Also returns true if
+// the options are valid as per the wire format, false otherwise.
+func getLinkAddrOption(it header.NDPOptionIterator, getAddr func(header.NDPOption) tcpip.LinkAddress) (tcpip.LinkAddress, bool) {
+ var linkAddr tcpip.LinkAddress
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ return "", false
+ }
+ if done {
+ break
+ }
+ if addr := getAddr(opt); len(addr) != 0 {
+ // No RFCs define what to do when an NDP message has multiple Link-Layer
+ // Address options. Since no interface can have multiple link-layer
+ // addresses, we consider such messages invalid.
+ if len(linkAddr) != 0 {
+ return "", false
+ }
+ linkAddr = addr
+ }
+ }
+ return linkAddr, true
+}
+
+// getSourceLinkAddr searches NDP options for the source link address option.
+// Returns the link address if found; otherwise, returns the zero link address
+// value. Also returns true if the options are valid as per the wire format,
+// false otherwise.
+func getSourceLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
+ return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress {
+ if src, ok := opt.(header.NDPSourceLinkLayerAddressOption); ok {
+ return src.EthernetAddress()
+ }
+ return ""
+ })
+}
+
+// getTargetLinkAddr searches NDP options for the target link address option.
+// Returns the link address if found; otherwise, returns the zero link address
+// value. Also returns true if the options are valid as per the wire format,
+// false otherwise.
+func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
+ return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress {
+ if dst, ok := opt.(header.NDPTargetLinkLayerAddressOption); ok {
+ return dst.EthernetAddress()
+ }
+ return ""
+ })
+}
+
func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
@@ -137,7 +188,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
+ if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
@@ -147,14 +198,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// NDP messages cannot be fragmented. Also note that in the common case NDP
// datagrams are very small and ToView() will not incur allocations.
ns := header.NDPNeighborSolicit(payload.ToView())
- it, err := ns.Options().Iter(true)
- if err != nil {
- // If we have a malformed NDP NS option, drop the packet.
+ targetAddr := ns.TargetAddress()
+
+ // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast
+ // address.
+ if header.IsV6MulticastAddress(targetAddr) {
received.Invalid.Increment()
return
}
- targetAddr := ns.TargetAddress()
s := r.Stack()
if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
// We will only get an error if the NIC is unrecognized, which should not
@@ -187,39 +239,22 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// so the packet is processed as defined in RFC 4861, as per RFC 4862
// section 5.4.3.
- // Is the NS targetting us?
- if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
+ // Is the NS targeting us?
+ if s.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
return
}
- // If the NS message contains the Source Link-Layer Address option, update
- // the link address cache with the value of the option.
- //
- // TODO(b/148429853): Properly process the NS message and do Neighbor
- // Unreachability Detection.
- var sourceLinkAddr tcpip.LinkAddress
- for {
- opt, done, err := it.Next()
- if err != nil {
- // This should never happen as Iter(true) above did not return an error.
- panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
- }
- if done {
- break
- }
+ it, err := ns.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
- switch opt := opt.(type) {
- case header.NDPSourceLinkLayerAddressOption:
- // No RFCs define what to do when an NS message has multiple Source
- // Link-Layer Address options. Since no interface can have multiple
- // link-layer addresses, we consider such messages invalid.
- if len(sourceLinkAddr) != 0 {
- received.Invalid.Increment()
- return
- }
-
- sourceLinkAddr = opt.EthernetAddress()
- }
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
}
unspecifiedSource := r.RemoteAddress == header.IPv6Any
@@ -237,6 +272,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
} else if unspecifiedSource {
received.Invalid.Increment()
return
+ } else if e.nud != nil {
+ e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr)
}
@@ -304,7 +341,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
+ if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertSize {
received.Invalid.Increment()
return
}
@@ -314,17 +351,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// 5, NDP messages cannot be fragmented. Also note that in the common case
// NDP datagrams are very small and ToView() will not incur allocations.
na := header.NDPNeighborAdvert(payload.ToView())
- it, err := na.Options().Iter(true)
- if err != nil {
- // If we have a malformed NDP NA option, drop the packet.
- received.Invalid.Increment()
- return
- }
-
targetAddr := na.TargetAddress()
- stack := r.Stack()
+ s := r.Stack()
- if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil {
+ if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
// We will only get an error if the NIC is unrecognized, which should not
// happen. For now short-circuit this packet.
//
@@ -335,7 +365,14 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// DAD on, implying the address is not unique. In this case we let the
// stack know so it can handle such a scenario and do nothing furthur with
// the NDP NA.
- stack.DupTentativeAddrDetected(e.nicID, targetAddr)
+ s.DupTentativeAddrDetected(e.nicID, targetAddr)
+ return
+ }
+
+ it, err := na.Options().Iter(false /* check */)
+ if err != nil {
+ // If we have a malformed NDP NA option, drop the packet.
+ received.Invalid.Increment()
return
}
@@ -348,39 +385,25 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// TODO(b/143147598): Handle the scenario described above. Also inform the
// netstack integration that a duplicate address was detected outside of
// DAD.
+ targetLinkAddr, ok := getTargetLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
// If the NA message has the target link layer option, update the link
// address cache with the link address for the target of the message.
- //
- // TODO(b/148429853): Properly process the NA message and do Neighbor
- // Unreachability Detection.
- var targetLinkAddr tcpip.LinkAddress
- for {
- opt, done, err := it.Next()
- if err != nil {
- // This should never happen as Iter(true) above did not return an error.
- panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
- }
- if done {
- break
+ if len(targetLinkAddr) != 0 {
+ if e.nud == nil {
+ e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
+ return
}
- switch opt := opt.(type) {
- case header.NDPTargetLinkLayerAddressOption:
- // No RFCs define what to do when an NA message has multiple Target
- // Link-Layer Address options. Since no interface can have multiple
- // link-layer addresses, we consider such messages invalid.
- if len(targetLinkAddr) != 0 {
- received.Invalid.Increment()
- return
- }
-
- targetLinkAddr = opt.EthernetAddress()
- }
- }
-
- if len(targetLinkAddr) != 0 {
- e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
+ e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{
+ Solicited: na.SolicitedFlag(),
+ Override: na.OverrideFlag(),
+ IsRouter: na.RouterFlag(),
+ })
}
case header.ICMPv6EchoRequest:
@@ -440,27 +463,75 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
case header.ICMPv6RouterSolicit:
received.RouterSolicit.Increment()
- if !isNDPValid() {
+
+ //
+ // Validate the RS as per RFC 4861 section 6.1.1.
+ //
+
+ // Is the NDP payload of sufficient size to hold a Router Solictation?
+ if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize {
received.Invalid.Increment()
return
}
- case header.ICMPv6RouterAdvert:
- received.RouterAdvert.Increment()
+ stack := r.Stack()
- // Is the NDP payload of sufficient size to hold a Router
- // Advertisement?
- if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() {
+ // Is the networking stack operating as a router?
+ if !stack.Forwarding() {
+ // ... No, silently drop the packet.
+ received.RouterOnlyPacketsDroppedByHost.Increment()
+ return
+ }
+
+ // Note that in the common case NDP datagrams are very small and ToView()
+ // will not incur allocations.
+ rs := header.NDPRouterSolicit(payload.ToView())
+ it, err := rs.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the packet.
received.Invalid.Increment()
return
}
- routerAddr := iph.SourceAddress()
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+
+ // If the RS message has the source link layer option, update the link
+ // address cache with the link address for the source of the message.
+ if len(sourceLinkAddr) != 0 {
+ // As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST
+ // NOT be included when the source IP address is the unspecified address.
+ // Otherwise, it SHOULD be included on link layers that have addresses.
+ if r.RemoteAddress == header.IPv6Any {
+ received.Invalid.Increment()
+ return
+ }
+
+ if e.nud != nil {
+ // A RS with a specified source IP address modifies the NUD state
+ // machine in the same way a reachability probe would.
+ e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ }
+ }
+
+ case header.ICMPv6RouterAdvert:
+ received.RouterAdvert.Increment()
//
// Validate the RA as per RFC 4861 section 6.1.2.
//
+ // Is the NDP payload of sufficient size to hold a Router Advertisement?
+ if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ routerAddr := iph.SourceAddress()
+
// Is the IP Source Address a link-local address?
if !header.IsV6LinkLocalAddress(routerAddr) {
// ...No, silently drop the packet.
@@ -468,16 +539,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- // The remainder of payload must be only the router advertisement, so
- // payload.ToView() always returns the advertisement. Per RFC 6980 section
- // 5, NDP messages cannot be fragmented. Also note that in the common case
- // NDP datagrams are very small and ToView() will not incur allocations.
+ // Note that in the common case NDP datagrams are very small and ToView()
+ // will not incur allocations.
ra := header.NDPRouterAdvert(payload.ToView())
- opts := ra.Options()
+ it, err := ra.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
- // Are options valid as per the wire format?
- if _, err := opts.Iter(true); err != nil {
- // ...No, silently drop the packet.
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
received.Invalid.Increment()
return
}
@@ -487,12 +560,33 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// as RFC 4861 section 6.1.2 is concerned.
//
+ // If the RA has the source link layer option, update the link address
+ // cache with the link address for the advertised router.
+ if len(sourceLinkAddr) != 0 && e.nud != nil {
+ e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ }
+
// Tell the NIC to handle the RA.
stack := r.Stack()
- rxNICID := r.NICID()
- stack.HandleNDPRA(rxNICID, routerAddr, ra)
+ stack.HandleNDPRA(e.nicID, routerAddr, ra)
case header.ICMPv6RedirectMsg:
+ // TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating
+ // this redirect message, as per RFC 4871 section 7.3.3:
+ //
+ // "A Neighbor Cache entry enters the STALE state when created as a
+ // result of receiving packets other than solicited Neighbor
+ // Advertisements (i.e., Router Solicitations, Router Advertisements,
+ // Redirects, and Neighbor Solicitations). These packets contain the
+ // link-layer address of either the sender or, in the case of Redirect,
+ // the redirection target. However, receipt of these link-layer
+ // addresses does not confirm reachability of the forward-direction path
+ // to that node. Placing a newly created Neighbor Cache entry for which
+ // the link-layer address is known in the STALE state provides assurance
+ // that path failures are detected quickly. In addition, should a cached
+ // link-layer address be modified due to receiving one of the above
+ // messages, the state SHOULD also be set to STALE to provide prompt
+ // verification that the path to the new link-layer address is working."
received.RedirectMsg.Increment()
if !isNDPValid() {
received.Invalid.Increment()
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 9e4eeea77..8112ed051 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -31,6 +31,8 @@ import (
)
const (
+ nicID = 1
+
linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
@@ -49,7 +51,10 @@ type stubLinkEndpoint struct {
}
func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return 0
+ // Indicate that resolution for link layer addresses is required to send
+ // packets over this link. This is needed so the NIC knows to allocate a
+ // neighbor table.
+ return stack.CapabilityResolutionRequired
}
func (*stubLinkEndpoint) MaxHeaderLength() uint16 {
@@ -84,16 +89,184 @@ func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtoco
func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {
}
+type stubNUDHandler struct{}
+
+var _ stack.NUDHandler = (*stubNUDHandler)(nil)
+
+func (*stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) {
+}
+
+func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) {
+}
+
+func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
+}
+
func TestICMPCounts(t *testing.T) {
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ UseNeighborCache: test.useNeighborCache,
+ })
+ {
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+ ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ defer ep.Close()
+
+ r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ }
+ defer r.Release()
+
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
+ types := []struct {
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ }{
+ {
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ },
+ {
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ },
+ {
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ },
+ {
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ },
+ }
+
+ handleIPv6Payload := func(icmp header.ICMPv6) {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize,
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ep.HandlePacket(&r, pkt)
+ }
+
+ for _, typ := range types {
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ handleIPv6Payload(icmp)
+ }
+
+ // Construct an empty ICMP packet so that
+ // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
+ handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
+
+ icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
+ if got, want := s.Value(), uint64(1); got != want {
+ t.Errorf("got %s = %d, want = %d", name, got, want)
+ }
+ })
+ if t.Failed() {
+ t.Logf("stats:\n%+v", s.Stats())
+ }
+ })
+ }
+}
+
+func TestICMPCountsWithNeighborCache(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ UseNeighborCache: true,
})
{
- if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
}
}
@@ -105,7 +278,7 @@ func TestICMPCounts(t *testing.T) {
s.SetRouteTable(
[]tcpip.Route{{
Destination: subnet,
- NIC: 1,
+ NIC: nicID,
}},
)
}
@@ -114,12 +287,12 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(0, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
defer ep.Close()
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
}
defer r.Release()
@@ -265,19 +438,19 @@ func newTestContext(t *testing.T) *testContext {
if testing.Verbose() {
wrappedEP0 = sniffer.New(wrappedEP0)
}
- if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
+ if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
- if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress lladdr0: %v", err)
}
c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
- if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
+ if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
+ if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil {
t.Fatalf("AddAddress lladdr1: %v", err)
}
@@ -288,7 +461,7 @@ func newTestContext(t *testing.T) *testContext {
c.s0.SetRouteTable(
[]tcpip.Route{{
Destination: subnet0,
- NIC: 1,
+ NIC: nicID,
}},
)
subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
@@ -298,7 +471,7 @@ func newTestContext(t *testing.T) *testContext {
c.s1.SetRouteTable(
[]tcpip.Route{{
Destination: subnet1,
- NIC: 1,
+ NIC: nicID,
}},
)
@@ -359,9 +532,9 @@ func TestLinkResolution(t *testing.T) {
c := newTestContext(t)
defer c.cleanup()
- r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
}
defer r.Release()
@@ -376,14 +549,14 @@ func TestLinkResolution(t *testing.T) {
var wq waiter.Queue
ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err)
}
for {
- _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}})
+ _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}})
if resCh != nil {
if err != tcpip.ErrNoLinkAddress {
- t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err)
+ t.Fatalf("ep.Write(_) = (_, <non-nil>, %s), want = (_, <non-nil>, tcpip.ErrNoLinkAddress)", err)
}
for _, args := range []routeArgs{
{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
@@ -399,7 +572,7 @@ func TestLinkResolution(t *testing.T) {
continue
}
if err != nil {
- t.Fatalf("ep.Write(_) = _, _, %s", err)
+ t.Fatalf("ep.Write(_) = (_, _, %s)", err)
}
break
}
@@ -424,6 +597,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
size int
extraData []byte
statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ routerOnly bool
}{
{
name: "DstUnreachable",
@@ -480,6 +654,8 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterSolicit
},
+ // Hosts MUST silently discard any received Router Solicitation messages.
+ routerOnly: true,
},
{
name: "RouterAdvert",
@@ -516,84 +692,133 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
},
}
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr0)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }},
- )
- }
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
- handleIPv6Payload := func(checksum bool) {
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- if checksum {
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, typ := range types {
+ for _, isRouter := range []bool{false, true} {
+ name := typ.name
+ if isRouter {
+ name += " (Router)"
+ }
+ t.Run(name, func(t *testing.T) {
+ e := channel.New(0, 1280, linkAddr0)
+
+ // Indicate that resolution for link layer addresses is required to
+ // send packets over this link. This is needed so the NIC knows to
+ // allocate a neighbor table.
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ UseNeighborCache: test.useNeighborCache,
+ })
+ if isRouter {
+ // Enabling forwarding makes the stack act as a router.
+ s.SetForwarding(true)
+ }
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
+ }
+
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(checksum bool) {
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ if checksum {
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
+ }
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
+ })
+ e.InjectInbound(ProtocolNumber, pkt)
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ routerOnly := stats.RouterOnlyPacketsDroppedByHost
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := routerOnly.Value(); got != 0 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Router only count should not have increased.
+ if got := routerOnly.Value(); got != 0 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ if !isRouter && typ.routerOnly && test.useNeighborCache {
+ // Router only count should have increased.
+ if got := routerOnly.Value(); got != 1 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got)
+ }
+ }
+ })
}
- ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
- })
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
}
})
}
@@ -696,11 +921,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
}
{
@@ -711,7 +936,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
s.SetRouteTable(
[]tcpip.Route{{
Destination: subnet,
- NIC: 1,
+ NIC: nicID,
}},
)
}
@@ -750,7 +975,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Without setting checksum, the incoming packet should
@@ -761,13 +986,13 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
}
// Rx count of type typ.typ should not have increased.
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// When checksum is set, it should be received.
handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Invalid count should not have increased again.
if got := invalid.Value(); got != 1 {
@@ -874,12 +1099,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -889,7 +1114,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
s.SetRouteTable(
[]tcpip.Route{{
Destination: subnet,
- NIC: 1,
+ NIC: nicID,
}},
)
}
@@ -929,7 +1154,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Without setting checksum, the incoming packet should
@@ -940,13 +1165,13 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
}
// Rx count of type typ.typ should not have increased.
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// When checksum is set, it should be received.
handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Invalid count should not have increased again.
if got := invalid.Value(); got != 1 {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 267d2cce8..36fbbebf0 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -48,6 +48,7 @@ type endpoint struct {
nicID tcpip.NICID
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
+ nud stack.NUDHandler
dispatcher stack.TransportDispatcher
protocol *protocol
stack *stack.Stack
@@ -455,11 +456,12 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
return &endpoint{
nicID: nicID,
linkEP: linkEP,
linkAddrCache: linkAddrCache,
+ nud: nud,
dispatcher: dispatcher,
protocol: p,
stack: st,
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index af71a7d6b..480c495fa 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -18,6 +18,7 @@ import (
"strings"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -30,12 +31,13 @@ import (
// setupStackAndEndpoint creates a stack with a single NIC with a link-local
// address llladdr and an IPv6 endpoint to a remote with link-local address
// rlladdr
-func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
+func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeighborCache bool) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ UseNeighborCache: useNeighborCache,
})
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
@@ -63,8 +65,7 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
-
+ ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
return s, ep
}
@@ -171,6 +172,123 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
}
}
+// TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests
+// that receiving a valid NDP NS message with the Source Link Layer Address
+// option results in a new entry in the link address cache for the sender of
+// the message.
+func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ neighbors, err := s.Neighbors(nicID)
+ if err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ }
+
+ neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
+ for _, n := range neighbors {
+ if existing, ok := neighborByAddr[n.Addr]; ok {
+ if diff := cmp.Diff(existing, n); diff != "" {
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
+ }
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing)
+ }
+ neighborByAddr[n.Addr] = n
+ }
+
+ if neigh, ok := neighborByAddr[lladdr1]; len(test.expectedLinkAddr) != 0 {
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+
+ if !ok {
+ t.Fatalf("expected a neighbor entry for %q", lladdr1)
+ }
+ if neigh.LinkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", neigh.LinkAddr, test.expectedLinkAddr)
+ }
+ if neigh.State != stack.Stale {
+ t.Errorf("got NUD state = %s, want = %s", neigh.State, stack.Stale)
+ }
+ } else {
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+
+ if ok {
+ t.Fatalf("unexpectedly got neighbor entry: %s", neigh)
+ }
+ }
+ })
+ }
+}
+
func TestNeighorSolicitationResponse(t *testing.T) {
const nicID = 1
nicAddr := lladdr0
@@ -180,6 +298,20 @@ func TestNeighorSolicitationResponse(t *testing.T) {
remoteLinkAddr0 := linkAddr1
remoteLinkAddr1 := linkAddr2
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
+
tests := []struct {
name string
nsOpts header.NDPOptionsSerializer
@@ -338,86 +470,92 @@ func TestNeighorSolicitationResponse(t *testing.T) {
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- e := channel.New(1, 1280, nicLinkAddr)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
- }
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ UseNeighborCache: stackTyp.useNeighborCache,
+ })
+ e := channel.New(1, 1280, nicLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ }
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
- ns.SetTargetAddress(nicAddr)
- opts := ns.Options()
- opts.Serialize(test.nsOpts)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: test.nsSrc,
- DstAddr: test.nsDst,
- })
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(nicAddr)
+ opts := ns.Options()
+ opts.Serialize(test.nsOpts)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
+ e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
- e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
+ if test.nsInvalid {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
- if test.nsInvalid {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
+ if p, got := e.Read(); got {
+ t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
+ }
- if p, got := e.Read(); got {
- t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
- }
+ // If we expected the NS to be invalid, we have nothing else to check.
+ return
+ }
- // If we expected the NS to be invalid, we have nothing else to check.
- return
- }
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
+ p, got := e.Read()
+ if !got {
+ t.Fatal("expected an NDP NA response")
+ }
- p, got := e.Read()
- if !got {
- t.Fatal("expected an NDP NA response")
- }
+ if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ }
- if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.naSrc),
+ checker.DstAddr(test.naDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNA(
+ checker.NDPNASolicitedFlag(test.naSolicited),
+ checker.NDPNATargetAddress(nicAddr),
+ checker.NDPNAOptions([]header.NDPOption{
+ header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
+ }),
+ ))
+ })
}
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.naSrc),
- checker.DstAddr(test.naDst),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNA(
- checker.NDPNASolicitedFlag(test.naSolicited),
- checker.NDPNATargetAddress(nicAddr),
- checker.NDPNAOptions([]header.NDPOption{
- header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
- }),
- ))
})
}
}
@@ -532,197 +670,380 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
}
}
-func TestNDPValidation(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
- t.Helper()
-
- // Create a stack with the assigned link-local address lladdr0
- // and an endpoint to lladdr1.
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
-
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
- }
-
- return s, ep, r
- }
-
- handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
- nextHdr := uint8(header.ICMPv6ProtocolNumber)
- var extensions buffer.View
- if atomicFragment {
- extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
- extensions[0] = nextHdr
- nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions),
- Data: payload.ToVectorisedView(),
- })
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions)))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload) + len(extensions)),
- NextHeader: nextHdr,
- HopLimit: hopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- })
- if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
- t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
- }
- ep.HandlePacket(r, pkt)
- }
-
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
+// TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests
+// that receiving a valid NDP NA message with the Target Link Layer Address
+// option does not result in a new entry in the neighbor cache for the target
+// of the message.
+func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) {
+ const nicID = 1
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- extraData []byte
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ tests := []struct {
+ name string
+ optsBuf []byte
+ isValid bool
}{
{
- name: "RouterSolicit",
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- },
- },
- {
- name: "RouterAdvert",
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- },
+ name: "Valid",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
+ isValid: true,
},
{
- name: "NeighborSolicit",
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- },
+ name: "Too Small",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
},
{
- name: "NeighborAdvert",
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- },
+ name: "Invalid Length",
+ optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
},
{
- name: "RedirectMsg",
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
+ name: "Multiple",
+ optsBuf: []byte{
+ 2, 1, 2, 3, 4, 5, 6, 7,
+ 2, 1, 2, 3, 4, 5, 6, 8,
},
},
}
- subTests := []struct {
- name string
- atomicFragment bool
- hopLimit uint8
- code header.ICMPv6Code
- valid bool
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr1)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ neighbors, err := s.Neighbors(nicID)
+ if err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ }
+
+ neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
+ for _, n := range neighbors {
+ if existing, ok := neighborByAddr[n.Addr]; ok {
+ if diff := cmp.Diff(existing, n); diff != "" {
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
+ }
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing)
+ }
+ neighborByAddr[n.Addr] = n
+ }
+
+ if neigh, ok := neighborByAddr[lladdr1]; ok {
+ t.Fatalf("unexpectedly got neighbor entry: %s", neigh)
+ }
+
+ if test.isValid {
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+func TestNDPValidation(t *testing.T) {
+ stacks := []struct {
+ name string
+ useNeighborCache bool
}{
{
- name: "Valid",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: true,
- },
- {
- name: "Fragmented",
- atomicFragment: true,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: false,
- },
- {
- name: "Invalid hop limit",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit - 1,
- code: 0,
- valid: false,
+ name: "linkAddrCache",
+ useNeighborCache: false,
},
{
- name: "Invalid ICMPv6 code",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 1,
- valid: false,
+ name: "neighborCache",
+ useNeighborCache: true,
},
}
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- for _, test := range subTests {
- t.Run(test.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ t.Helper()
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
+ // Create a stack with the assigned link-local address lladdr0
+ // and an endpoint to lladdr1.
+ s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache)
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
- // Rx count of the NDP message should initially be 0.
- if got := typStat.Value(); got != 0 {
- t.Errorf("got %s = %d, want = 0", typ.name, got)
- }
+ return s, ep, r
+ }
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ nextHdr := uint8(header.ICMPv6ProtocolNumber)
+ var extensions buffer.View
+ if atomicFragment {
+ extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
+ extensions[0] = nextHdr
+ nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ }
- if t.Failed() {
- t.FailNow()
- }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions),
+ Data: payload.ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions)))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(payload) + len(extensions)),
+ NextHeader: nextHdr,
+ HopLimit: hopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
+ t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
+ }
+ ep.HandlePacket(r, pkt)
+ }
- handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
- // Rx count of the NDP packet should have increased.
- if got := typStat.Value(); got != 1 {
- t.Errorf("got %s = %d, want = 1", typ.name, got)
- }
+ var sllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
- want := uint64(0)
- if !test.valid {
- // Invalid count should have increased.
- want = 1
- }
- if got := invalid.Value(); got != want {
- t.Errorf("got invalid = %d, want = %d", got, want)
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ routerOnly bool
+ }{
+ {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ routerOnly: true,
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ extraData: sllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ atomicFragment bool
+ hopLimit uint8
+ code header.ICMPv6Code
+ valid bool
+ }{
+ {
+ name: "Valid",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: true,
+ },
+ {
+ name: "Fragmented",
+ atomicFragment: true,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid hop limit",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit - 1,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid ICMPv6 code",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 1,
+ valid: false,
+ },
+ }
+
+ for _, typ := range types {
+ for _, isRouter := range []bool{false, true} {
+ name := typ.name
+ if isRouter {
+ name += " (Router)"
}
- })
+
+ t.Run(name, func(t *testing.T) {
+ for _, test := range subTests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep, r := setup(t)
+ defer r.Release()
+
+ if isRouter {
+ // Enabling forwarding makes the stack act as a router.
+ s.SetForwarding(true)
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ routerOnly := stats.RouterOnlyPacketsDroppedByHost
+ typStat := typ.statCounter(stats)
+
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetCode(test.code)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+
+ // Rx count of the NDP message should initially be 0.
+ if got := typStat.Value(); got != 0 {
+ t.Errorf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+
+ // RouterOnlyPacketsReceivedByHost count should initially be 0.
+ if got := routerOnly.Value(); got != 0 {
+ t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+
+ // Rx count of the NDP packet should have increased.
+ if got := typStat.Value(); got != 1 {
+ t.Errorf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ want := uint64(0)
+ if !test.valid {
+ // Invalid count should have increased.
+ want = 1
+ }
+ if got := invalid.Value(); got != want {
+ t.Errorf("got invalid = %d, want = %d", got, want)
+ }
+
+ want = 0
+ if test.valid && !isRouter && typ.routerOnly {
+ // RouterOnlyPacketsReceivedByHost count should have increased.
+ want = 1
+ }
+ if got := routerOnly.Value(); got != want {
+ t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want)
+ }
+
+ })
+ }
+ })
+ }
}
})
}
+
}
// TestRouterAdvertValidation tests that when the NIC is configured to handle
// NDP Router Advertisement packets, it validates the Router Advertisement
// properly before handling them.
func TestRouterAdvertValidation(t *testing.T) {
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
+
tests := []struct {
name string
src tcpip.Address
@@ -844,61 +1165,67 @@ func TestRouterAdvertValidation(t *testing.T) {
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ UseNeighborCache: stackTyp.useNeighborCache,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
- icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(header.ICMPv6RouterAdvert)
- pkt.SetCode(test.code)
- copy(pkt.NDPPayload(), test.ndpPayload)
- payloadLength := hdr.UsedLength()
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- })
+ icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(test.code)
+ copy(pkt.NDPPayload(), test.ndpPayload)
+ payloadLength := hdr.UsedLength()
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- rxRA := stats.RouterAdvert
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ rxRA := stats.RouterAdvert
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := rxRA.Value(); got != 0 {
- t.Fatalf("got rxRA = %d, want = 0", got)
- }
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := rxRA.Value(); got != 0 {
+ t.Fatalf("got rxRA = %d, want = 0", got)
+ }
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
- if got := rxRA.Value(); got != 1 {
- t.Fatalf("got rxRA = %d, want = 1", got)
- }
+ if got := rxRA.Value(); got != 1 {
+ t.Fatalf("got rxRA = %d, want = 1", got)
+ }
- if test.expectedSuccess {
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- } else {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
+ if test.expectedSuccess {
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
}
})
}