summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/arp/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp.go58
-rw-r--r--pkg/tcpip/network/arp/arp_test.go331
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD4
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go25
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go57
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go23
-rw-r--r--pkg/tcpip/network/ip_test.go14
-rw-r--r--pkg/tcpip/network/ipv4/BUILD2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go90
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go749
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go278
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go447
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go185
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go424
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go875
-rw-r--r--pkg/tcpip/network/testutil/BUILD17
-rw-r--r--pkg/tcpip/network/testutil/testutil.go92
19 files changed, 2804 insertions, 870 deletions
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index eddf7b725..b40dde96b 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/stack",
],
)
@@ -28,5 +29,6 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 920872c3f..cb9225bd7 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -29,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -46,6 +47,7 @@ type endpoint struct {
nicID tcpip.NICID
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
+ nud stack.NUDHandler
}
// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
@@ -78,7 +80,7 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return e.protocol.Number()
+ return ProtocolNumber
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
@@ -99,9 +101,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
switch h.Op() {
case header.ARPRequest:
localAddr := tcpip.Address(h.ProtocolAddressTarget())
- if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
- return // we have no useful answer, ignore the request
+
+ if e.nud == nil {
+ if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+ } else {
+ if r.Stack().CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+
+ remoteAddr := tcpip.Address(h.ProtocolAddressSender())
+ remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
+
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize,
})
@@ -113,11 +131,28 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
_ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
- fallthrough // also fill the cache from requests
+
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+
+ if e.nud == nil {
+ e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+ return
+ }
+
+ // The solicited, override, and isRouter flags are not available for ARP;
+ // they are only available for IPv6 Neighbor Advertisements.
+ e.nud.HandleConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{
+ // Solicited and unsolicited (also referred to as gratuitous) ARP Replies
+ // are handled equivalently to a solicited Neighbor Advertisement.
+ Solicited: true,
+ // If a different link address is received than the one cached, the entry
+ // should always go to Stale.
+ Override: false,
+ // ARP does not distinguish between router and non-router hosts.
+ IsRouter: false,
+ })
}
}
@@ -134,12 +169,13 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
return &endpoint{
protocol: p,
nicID: nicID,
linkEP: sender,
linkAddrCache: linkAddrCache,
+ nud: nud,
}
}
@@ -182,12 +218,12 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
// SetOption implements stack.NetworkProtocol.SetOption.
-func (*protocol) SetOption(option interface{}) *tcpip.Error {
+func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements stack.NetworkProtocol.Option.
-func (*protocol) Option(option interface{}) *tcpip.Error {
+func (*protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
@@ -199,11 +235,7 @@ func (*protocol) Wait() {}
// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
- _, ok = pkt.NetworkHeader().Consume(header.ARPSize)
- if !ok {
- return 0, false, false
- }
- return 0, false, true
+ return 0, false, parse.ARP(pkt)
}
// NewProtocol returns an ARP network protocol.
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index c2c3e6891..9c9a859e3 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -16,10 +16,12 @@ package arp_test
import (
"context"
+ "fmt"
"strconv"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -32,57 +34,192 @@ import (
)
const (
- stackLinkAddr1 = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
- stackLinkAddr2 = tcpip.LinkAddress("\x0b\x0b\x0c\x0c\x0d\x0d")
- stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
- stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
- stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+ nicID = 1
+
+ stackAddr = tcpip.Address("\x0a\x00\x00\x01")
+ stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
+
+ remoteAddr = tcpip.Address("\x0a\x00\x00\x02")
+ remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
+
+ unknownAddr = tcpip.Address("\x0a\x00\x00\x03")
defaultChannelSize = 1
defaultMTU = 65536
+
+ // eventChanSize defines the size of event channels used by the neighbor
+ // cache's event dispatcher. The size chosen here needs to be sufficient to
+ // queue all the events received during tests before consumption.
+ // If eventChanSize is too small, the tests may deadlock.
+ eventChanSize = 32
+)
+
+type eventType uint8
+
+const (
+ entryAdded eventType = iota
+ entryChanged
+ entryRemoved
)
+func (t eventType) String() string {
+ switch t {
+ case entryAdded:
+ return "add"
+ case entryChanged:
+ return "change"
+ case entryRemoved:
+ return "remove"
+ default:
+ return fmt.Sprintf("unknown (%d)", t)
+ }
+}
+
+type eventInfo struct {
+ eventType eventType
+ nicID tcpip.NICID
+ addr tcpip.Address
+ linkAddr tcpip.LinkAddress
+ state stack.NeighborState
+}
+
+func (e eventInfo) String() string {
+ return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state)
+}
+
+// arpDispatcher implements NUDDispatcher to validate the dispatching of
+// events upon certain NUD state machine events.
+type arpDispatcher struct {
+ // C is where events are queued
+ C chan eventInfo
+}
+
+var _ stack.NUDDispatcher = (*arpDispatcher)(nil)
+
+func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+ e := eventInfo{
+ eventType: entryAdded,
+ nicID: nicID,
+ addr: addr,
+ linkAddr: linkAddr,
+ state: state,
+ }
+ d.C <- e
+}
+
+func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+ e := eventInfo{
+ eventType: entryChanged,
+ nicID: nicID,
+ addr: addr,
+ linkAddr: linkAddr,
+ state: state,
+ }
+ d.C <- e
+}
+
+func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+ e := eventInfo{
+ eventType: entryRemoved,
+ nicID: nicID,
+ addr: addr,
+ linkAddr: linkAddr,
+ state: state,
+ }
+ d.C <- e
+}
+
+func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
+ select {
+ case got := <-d.C:
+ if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" {
+ return fmt.Errorf("got invalid event (-got +want):\n%s", diff)
+ }
+ case <-ctx.Done():
+ return fmt.Errorf("%s for %s", ctx.Err(), want)
+ }
+ return nil
+}
+
+func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error {
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ return d.waitForEvent(ctx, want)
+}
+
+func (d *arpDispatcher) nextEvent() (eventInfo, bool) {
+ select {
+ case event := <-d.C:
+ return event, true
+ default:
+ return eventInfo{}, false
+ }
+}
+
type testContext struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
+ s *stack.Stack
+ linkEP *channel.Endpoint
+ nudDisp *arpDispatcher
}
-func newTestContext(t *testing.T) *testContext {
+func newTestContext(t *testing.T, useNeighborCache bool) *testContext {
+ c := stack.DefaultNUDConfigurations()
+ // Transition from Reachable to Stale almost immediately to test if receiving
+ // probes refreshes positive reachability.
+ c.BaseReachableTime = time.Microsecond
+
+ d := arpDispatcher{
+ // Create an event channel large enough so the neighbor cache doesn't block
+ // while dispatching events. Blocking could interfere with the timing of
+ // NUD transitions.
+ C: make(chan eventInfo, eventChanSize),
+ }
+
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ NUDConfigs: c,
+ NUDDisp: &d,
+ UseNeighborCache: useNeighborCache,
})
- ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
+ ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, wep); err != nil {
+ if err := s.CreateNIC(nicID, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
t.Fatalf("AddAddress for ipv4 failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ if !useNeighborCache {
+ // The remote address needs to be assigned to the NIC so we can receive and
+ // verify outgoing ARP packets. The neighbor cache isn't concerned with
+ // this; the tests that use linkAddrCache expect the ARP responses to be
+ // received by the same NIC.
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, remoteAddr); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ }
}
- if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
t.Fatalf("AddAddress for arp failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
- NIC: 1,
+ NIC: nicID,
}})
return &testContext{
- t: t,
- s: s,
- linkEP: ep,
+ s: s,
+ linkEP: ep,
+ nudDisp: &d,
}
}
@@ -91,7 +228,7 @@ func (c *testContext) cleanup() {
}
func TestDirectRequest(t *testing.T) {
- c := newTestContext(t)
+ c := newTestContext(t, false /* useNeighborCache */)
defer c.cleanup()
const senderMAC = "\x01\x02\x03\x04\x05\x06"
@@ -111,7 +248,7 @@ func TestDirectRequest(t *testing.T) {
}))
}
- for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
+ for i, address := range []tcpip.Address{stackAddr, remoteAddr} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
inject(address)
pi, _ := c.linkEP.ReadContext(context.Background())
@@ -122,7 +259,7 @@ func TestDirectRequest(t *testing.T) {
if !rep.IsValid() {
t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
}
- if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want {
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
}
if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
@@ -137,7 +274,7 @@ func TestDirectRequest(t *testing.T) {
})
}
- inject(stackAddrBad)
+ inject(unknownAddr)
// Sleep tests are gross, but this will only potentially flake
// if there's a bug. If there is no bug this will reliably
// succeed.
@@ -148,6 +285,144 @@ func TestDirectRequest(t *testing.T) {
}
}
+func TestDirectRequestWithNeighborCache(t *testing.T) {
+ c := newTestContext(t, true /* useNeighborCache */)
+ defer c.cleanup()
+
+ tests := []struct {
+ name string
+ senderAddr tcpip.Address
+ senderLinkAddr tcpip.LinkAddress
+ targetAddr tcpip.Address
+ isValid bool
+ }{
+ {
+ name: "Loopback",
+ senderAddr: stackAddr,
+ senderLinkAddr: stackLinkAddr,
+ targetAddr: stackAddr,
+ isValid: true,
+ },
+ {
+ name: "Remote",
+ senderAddr: remoteAddr,
+ senderLinkAddr: remoteLinkAddr,
+ targetAddr: stackAddr,
+ isValid: true,
+ },
+ {
+ name: "RemoteInvalidTarget",
+ senderAddr: remoteAddr,
+ senderLinkAddr: remoteLinkAddr,
+ targetAddr: unknownAddr,
+ isValid: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Inject an incoming ARP request.
+ v := make(buffer.View, header.ARPSize)
+ h := header.ARP(v)
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), test.senderLinkAddr)
+ copy(h.ProtocolAddressSender(), test.senderAddr)
+ copy(h.ProtocolAddressTarget(), test.targetAddr)
+ c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{
+ Data: v.ToVectorisedView(),
+ })
+
+ if !test.isValid {
+ // No packets should be sent after receiving an invalid ARP request.
+ // There is no need to perform a blocking read here, since packets are
+ // sent in the same function that handles ARP requests.
+ if pkt, ok := c.linkEP.Read(); ok {
+ t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto)
+ }
+ return
+ }
+
+ // Verify an ARP response was sent.
+ pi, ok := c.linkEP.Read()
+ if !ok {
+ t.Fatal("expected ARP response to be sent, got none")
+ }
+
+ if pi.Proto != arp.ProtocolNumber {
+ t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
+ }
+ rep := header.ARP(pi.Pkt.NetworkHeader().View())
+ if !rep.IsValid() {
+ t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ t.Errorf("got HardwareAddressSender() = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
+ t.Errorf("got ProtocolAddressSender() = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want {
+ t.Errorf("got HardwareAddressTarget() = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
+ t.Errorf("got ProtocolAddressTarget() = %s, want = %s", got, want)
+ }
+
+ // Verify the sender was saved in the neighbor cache.
+ wantEvent := eventInfo{
+ eventType: entryAdded,
+ nicID: nicID,
+ addr: test.senderAddr,
+ linkAddr: tcpip.LinkAddress(test.senderLinkAddr),
+ state: stack.Stale,
+ }
+ if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil {
+ t.Fatal(err)
+ }
+
+ neighbors, err := c.s.Neighbors(nicID)
+ if err != nil {
+ t.Fatalf("c.s.Neighbors(%d): %s", nicID, err)
+ }
+
+ neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
+ for _, n := range neighbors {
+ if existing, ok := neighborByAddr[n.Addr]; ok {
+ if diff := cmp.Diff(existing, n); diff != "" {
+ t.Fatalf("duplicate neighbor entry found (-existing +got):\n%s", diff)
+ }
+ t.Fatalf("exact neighbor entry duplicate found for addr=%s", n.Addr)
+ }
+ neighborByAddr[n.Addr] = n
+ }
+
+ neigh, ok := neighborByAddr[test.senderAddr]
+ if !ok {
+ t.Fatalf("expected neighbor entry with Addr = %s", test.senderAddr)
+ }
+ if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want {
+ t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want)
+ }
+ if got, want := neigh.LocalAddr, stackAddr; got != want {
+ t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want)
+ }
+ if got, want := neigh.State, stack.Stale; got != want {
+ t.Errorf("got neighbor State = %s, want = %s", got, want)
+ }
+
+ // No more events should be dispatched
+ for {
+ event, ok := c.nudDisp.nextEvent()
+ if !ok {
+ break
+ }
+ t.Errorf("unexpected %s", event)
+ }
+ })
+ }
+}
+
func TestLinkAddressRequest(t *testing.T) {
tests := []struct {
name string
@@ -156,8 +431,8 @@ func TestLinkAddressRequest(t *testing.T) {
}{
{
name: "Unicast",
- remoteLinkAddr: stackLinkAddr2,
- expectLinkAddr: stackLinkAddr2,
+ remoteLinkAddr: remoteLinkAddr,
+ expectLinkAddr: remoteLinkAddr,
},
{
name: "Multicast",
@@ -173,9 +448,9 @@ func TestLinkAddressRequest(t *testing.T) {
t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
}
- linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
- if err := linkRes.LinkAddressRequest(stackAddr1, stackAddr2, test.remoteLinkAddr, linkEP); err != nil {
- t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr1, stackAddr2, test.remoteLinkAddr, err)
+ linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err)
}
pkt, ok := linkEP.Read()
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index d1c728ccf..96c5f42f8 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -41,5 +41,7 @@ go_test(
"reassembler_test.go",
],
library = ":fragmentation",
- deps = ["//pkg/tcpip/buffer"],
+ deps = [
+ "//pkg/tcpip/buffer",
+ ],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 1827666c5..6a4843f92 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -120,29 +120,36 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
}
// Process processes an incoming fragment belonging to an ID and returns a
-// complete packet when all the packets belonging to that ID have been received.
+// complete packet and its protocol number when all the packets belonging to
+// that ID have been received.
//
// [first, last] is the range of the fragment bytes.
//
// first must be a multiple of the block size f is configured with. The size
// of the fragment data must be a multiple of the block size, unless there are
// no fragments following this fragment (more set to false).
-func (f *Fragmentation) Process(id FragmentID, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
+//
+// proto is the protocol number marked in the fragment being processed. It has
+// to be given here outside of the FragmentID struct because IPv6 should not use
+// the protocol to identify a fragment.
+func (f *Fragmentation) Process(
+ id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (
+ buffer.VectorisedView, uint8, bool, error) {
if first > last {
- return buffer.VectorisedView{}, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
}
if first%f.blockSize != 0 {
- return buffer.VectorisedView{}, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
}
fragmentSize := last - first + 1
if more && fragmentSize%f.blockSize != 0 {
- return buffer.VectorisedView{}, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
}
if l := vv.Size(); l < int(fragmentSize) {
- return buffer.VectorisedView{}, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
}
vv.CapLength(int(fragmentSize))
@@ -160,14 +167,14 @@ func (f *Fragmentation) Process(id FragmentID, first, last uint16, more bool, vv
}
f.mu.Unlock()
- res, done, consumed, err := r.process(first, last, more, vv)
+ res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv)
if err != nil {
// We probably got an invalid sequence of fragments. Just
// discard the reassembler and move on.
f.mu.Lock()
f.release(r)
f.mu.Unlock()
- return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err)
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
}
f.mu.Lock()
f.size += consumed
@@ -186,7 +193,7 @@ func (f *Fragmentation) Process(id FragmentID, first, last uint16, more bool, vv
}
}
f.mu.Unlock()
- return res, done, nil
+ return res, firstFragmentProto, done, nil
}
func (f *Fragmentation) release(r *reassembler) {
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 9eedd33c4..416604659 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -38,12 +38,14 @@ type processInput struct {
first uint16
last uint16
more bool
+ proto uint8
vv buffer.VectorisedView
}
type processOutput struct {
- vv buffer.VectorisedView
- done bool
+ vv buffer.VectorisedView
+ proto uint8
+ done bool
}
var processTestCases = []struct {
@@ -63,6 +65,17 @@ var processTestCases = []struct {
},
},
{
+ comment: "Next Header protocol mismatch",
+ in: []processInput{
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")},
+ },
+ out: []processOutput{
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: vv(4, "01", "23"), proto: 6, done: true},
+ },
+ },
+ {
comment: "Two IDs",
in: []processInput{
{id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
@@ -83,18 +96,26 @@ func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout)
+ firstFragmentProto := c.in[0].proto
for i, in := range c.in {
- vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv)
if err != nil {
- t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err)
+ t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s",
+ in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err)
}
if !reflect.DeepEqual(vv, c.out[i].vv) {
- t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv)
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)",
+ in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView())
}
if done != c.out[i].done {
- t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done)
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)",
+ in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done)
}
if c.out[i].done {
+ if firstFragmentProto != proto {
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)",
+ in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto)
+ }
if _, ok := f.reassemblers[in.id]; ok {
t.Errorf("Process(%d) did not remove buffer from reassemblers", i)
}
@@ -113,14 +134,14 @@ func TestReassemblingTimeout(t *testing.T) {
timeout := time.Millisecond
f := NewFragmentation(minBlockSize, 1024, 512, timeout)
// Send first fragment with id = 0, first = 0, last = 0, and more = true.
- f.Process(FragmentID{}, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
// Sleep more than the timeout.
time.Sleep(2 * timeout)
// Send another fragment that completes a packet.
// However, no packet should be reassembled because the fragment arrived after the timeout.
- _, done, err := f.Process(FragmentID{}, 1, 1, false, vv(1, "1"))
+ _, _, done, err := f.Process(FragmentID{}, 1, 1, false, 0xFF, vv(1, "1"))
if err != nil {
- t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err)
+ t.Fatalf("f.Process(0, 1, 1, false, 0xFF, vv(1, \"1\")) failed: %v", err)
}
if done {
t.Errorf("Fragmentation does not respect the reassembling timeout.")
@@ -130,15 +151,15 @@ func TestReassemblingTimeout(t *testing.T) {
func TestMemoryLimits(t *testing.T) {
f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout)
// Send first fragment with id = 0.
- f.Process(FragmentID{ID: 0}, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"))
// Send first fragment with id = 1.
- f.Process(FragmentID{ID: 1}, 0, 0, true, vv(1, "1"))
+ f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"))
// Send first fragment with id = 2.
- f.Process(FragmentID{ID: 2}, 0, 0, true, vv(1, "2"))
+ f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"))
// Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
// evicted.
- f.Process(FragmentID{ID: 3}, 0, 0, true, vv(1, "3"))
+ f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"))
if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
@@ -154,9 +175,9 @@ func TestMemoryLimits(t *testing.T) {
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout)
// Send first fragment with id = 0.
- f.Process(FragmentID{}, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
// Send the same packet again.
- f.Process(FragmentID{}, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
got := f.size
want := 1
@@ -248,12 +269,12 @@ func TestErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout)
- _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, vv(len(test.data), test.data))
+ _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data))
if !errors.Is(err, test.err) {
- t.Errorf("got Proceess(_, %d, %d, %t, %q) = (_, _, %v), want = (_, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
+ t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
}
if done {
- t.Errorf("got Proceess(_, %d, %d, %t, %q) = (_, true, _), want = (_, false, _)", test.first, test.last, test.more, test.data)
+ t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data)
}
})
}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 50d30bbf0..f044867dc 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -34,6 +34,7 @@ type reassembler struct {
reassemblerEntry
id FragmentID
size int
+ proto uint8
mu sync.Mutex
holes []hole
deleted int
@@ -46,7 +47,6 @@ func newReassembler(id FragmentID) *reassembler {
r := &reassembler{
id: id,
holes: make([]hole, 0, 16),
- deleted: 0,
heap: make(fragHeap, 0, 8),
creationTime: time.Now(),
}
@@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
return used
}
-func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) {
+func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
consumed := 0
@@ -86,7 +86,18 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, false, consumed, nil
+ return buffer.VectorisedView{}, 0, false, consumed, nil
+ }
+ // For IPv6, it is possible to have different Protocol values between
+ // fragments of a packet (because, unlike IPv4, the Protocol is not used to
+ // identify a fragment). In this case, only the Protocol of the first
+ // fragment must be used as per RFC 8200 Section 4.5.
+ //
+ // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded
+ // here (instead of just the protocol) because most IP options should be
+ // derived from the first fragment.
+ if first == 0 {
+ r.proto = proto
}
if r.updateHoles(first, last, more) {
// We store the incoming packet only if it filled some holes.
@@ -96,13 +107,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
}
// Check if all the holes have been deleted and we are ready to reassamble.
if r.deleted < len(r.holes) {
- return buffer.VectorisedView{}, false, consumed, nil
+ return buffer.VectorisedView{}, 0, false, consumed, nil
}
res, err := r.heap.reassemble()
if err != nil {
- return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err)
+ return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err)
}
- return res, true, consumed, nil
+ return res, r.proto, true, consumed, nil
}
func (r *reassembler) tooOld(timeout time.Duration) bool {
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 9007346fe..e45dd17f8 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -250,7 +250,7 @@ func buildDummyStack(t *testing.T) *stack.Stack {
func TestIPv4Send(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, nil, &o, buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, buildDummyStack(t))
defer ep.Close()
// Allocate and initialize the payload view.
@@ -287,7 +287,7 @@ func TestIPv4Send(t *testing.T) {
func TestIPv4Receive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
totalLen := header.IPv4MinimumSize + 30
@@ -357,7 +357,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv4.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
@@ -418,7 +418,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
func TestIPv4FragmentationReceive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
totalLen := header.IPv4MinimumSize + 24
@@ -495,7 +495,7 @@ func TestIPv4FragmentationReceive(t *testing.T) {
func TestIPv6Send(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t))
defer ep.Close()
// Allocate and initialize the payload view.
@@ -532,7 +532,7 @@ func TestIPv6Send(t *testing.T) {
func TestIPv6Receive(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
totalLen := header.IPv6MinimumSize + 30
@@ -611,7 +611,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index d142b4ffa..f9c2aa980 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -13,6 +13,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
@@ -30,6 +31,7 @@ go_test(
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 79872ec9a..b14b356d6 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -59,7 +60,7 @@ type endpoint struct {
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
return &endpoint{
nicID: nicID,
linkEP: linkEP,
@@ -235,14 +236,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
ipt := e.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
- // If the packet is manipulated as per NAT Ouput rules, handle packet
- // based on destination address and do not send the packet to link layer.
- // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than
- // only NATted packets, but removing this check short circuits broadcasts
- // before they are sent out to other hosts.
+ // If the packet is manipulated as per NAT Output rules, handle packet
+ // based on destination address and do not send the packet to link
+ // layer.
+ //
+ // TODO(gvisor.dev/issue/170): We should do this for every
+ // packet, rather than only NATted packets, but removing this check
+ // short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
@@ -297,8 +301,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
return n, err
}
+ r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
- // Slow Path as we are dropping some packets in the batch degrade to
+ // Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
@@ -318,12 +323,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ // Dropped packets aren't errors, so include them in
+ // the return value.
+ return n + len(dropped), err
}
n++
}
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, nil
+ // Dropped packets aren't errors, so include them in the return value.
+ return n + len(dropped), nil
}
// WriteHeaderIncludedPacket writes a packet already containing a network
@@ -392,6 +400,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ipt := e.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesInputDropped.Increment()
return
}
@@ -404,29 +413,35 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
// The packet is a fragment, let's try to reassemble it.
- last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1
- // Drop the packet if the fragmentOffset is incorrect. i.e the
- // combination of fragmentOffset and pkt.Data.size() causes a
- // wrap around resulting in last being less than the offset.
- if last < h.FragmentOffset() {
+ start := h.FragmentOffset()
+ // Drop the fragment if the size of the reassembled payload would exceed the
+ // maximum payload size.
+ //
+ // Note that this addition doesn't overflow even on 32bit architecture
+ // because pkt.Data.Size() should not exceed 65535 (the max IP datagram
+ // size). Otherwise the packet would've been rejected as invalid before
+ // reaching here.
+ if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
var ready bool
var err error
- pkt.Data, ready, err = e.protocol.fragmentation.Process(
+ proto := h.Protocol()
+ pkt.Data, _, ready, err = e.protocol.fragmentation.Process(
// As per RFC 791 section 2.3, the identification value is unique
// for a source-destination pair and protocol.
fragmentation.FragmentID{
Source: h.SourceAddress(),
Destination: h.DestinationAddress(),
ID: uint32(h.ID()),
- Protocol: h.Protocol(),
+ Protocol: proto,
},
- h.FragmentOffset(),
- last,
+ start,
+ start+uint16(pkt.Data.Size())-1,
h.More(),
+ proto,
pkt.Data,
)
if err != nil {
@@ -484,10 +499,10 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// SetOption implements NetworkProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case tcpip.DefaultTTLOption:
- p.SetDefaultTTL(uint8(v))
+ case *tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(*v))
return nil
default:
return tcpip.ErrUnknownProtocolOption
@@ -495,7 +510,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
// Option implements NetworkProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(p.DefaultTTL())
@@ -521,37 +536,14 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// Parse implements stack.TransportProtocol.Parse.
+// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
- hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
+ if ok := parse.IPv4(pkt); !ok {
return 0, false, false
}
- ipHdr := header.IPv4(hdr)
-
- // Header may have options, determine the true header length.
- headerLen := int(ipHdr.HeaderLength())
- if headerLen < header.IPv4MinimumSize {
- // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in
- // order for the packet to be valid. Figure out if we want to reject this
- // case.
- headerLen = header.IPv4MinimumSize
- }
- hdr, ok = pkt.NetworkHeader().Consume(headerLen)
- if !ok {
- return 0, false, false
- }
- ipHdr = header.IPv4(hdr)
-
- // If this is a fragment, don't bother parsing the transport header.
- parseTransportHeader := true
- if ipHdr.More() || ipHdr.FragmentOffset() != 0 {
- parseTransportHeader = false
- }
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
- return ipHdr.TransportProtocol(), parseTransportHeader, true
+ ipHdr := header.IPv4(pkt.NetworkHeader().View())
+ return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 197e3bc51..b14bc98e8 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -17,8 +17,6 @@ package ipv4_test
import (
"bytes"
"encoding/hex"
- "fmt"
- "math/rand"
"testing"
"github.com/google/go-cmp/cmp"
@@ -28,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -92,31 +91,6 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-// makeRandPkt generates a randomize packet. hdrLength indicates how much
-// data should already be in the header before WritePacket. extraLength
-// indicates how much extra space should be in the header. The payload is made
-// from many Views of the sizes listed in viewSizes.
-func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer {
- var views []buffer.View
- totalLength := 0
- for _, s := range viewSizes {
- newView := buffer.NewView(s)
- rand.Read(newView)
- views = append(views, newView)
- totalLength += s
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: hdrLength + extraLength,
- Data: buffer.NewVectorisedView(totalLength, views),
- })
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil {
- panic(fmt.Sprintf("rand.Read: %s", err))
- }
- return pkt
-}
-
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
@@ -186,63 +160,19 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
}
-type errorChannel struct {
- *channel.Endpoint
- Ch chan *stack.PacketBuffer
- packetCollectorErrors []*tcpip.Error
-}
-
-// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
-// will return successive errors from packetCollectorErrors until the list is
-// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
- return &errorChannel{
- Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan *stack.PacketBuffer, size),
- packetCollectorErrors: packetCollectorErrors,
- }
-}
-
-// Drain removes all outbound packets from the channel and counts them.
-func (e *errorChannel) Drain() int {
- c := 0
- for {
- select {
- case <-e.Ch:
- c++
- default:
- return c
- }
- }
-}
-
-// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- select {
- case e.Ch <- pkt:
- default:
- }
-
- nextError := (*tcpip.Error)(nil)
- if len(e.packetCollectorErrors) > 0 {
- nextError = e.packetCollectorErrors[0]
- e.packetCollectorErrors = e.packetCollectorErrors[1:]
- }
- return nextError
-}
-
-type context struct {
+type testRoute struct {
stack.Route
- linkEP *errorChannel
+
+ linkEP *testutil.TestEndpoint
}
-func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
+func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute {
// Make the packet and write it.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
})
- ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- s.CreateNIC(1, ep)
+ testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors)
+ s.CreateNIC(1, testEP)
const (
src = "\x10\x00\x00\x01"
dst = "\x10\x00\x00\x02"
@@ -262,9 +192,12 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
if err != nil {
t.Fatalf("s.FindRoute got %v, want %v", err, nil)
}
- return context{
+ t.Cleanup(func() {
+ testEP.Close()
+ })
+ return testRoute{
Route: r,
- linkEP: ep,
+ linkEP: testEP,
}
}
@@ -274,13 +207,13 @@ func TestFragmentation(t *testing.T) {
manyPayloadViewsSizes[i] = 7
}
fragTests := []struct {
- description string
- mtu uint32
- gso *stack.GSO
- hdrLength int
- extraLength int
- payloadViewsSizes []int
- expectedFrags int
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transportHeaderLength int
+ extraHeaderReserveLength int
+ payloadViewsSizes []int
+ expectedFrags int
}{
{"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
{"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
@@ -295,10 +228,10 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
+ r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
source := pkt.Clone()
- c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
@@ -307,24 +240,13 @@ func TestFragmentation(t *testing.T) {
t.Errorf("err got %v, want %v", err, nil)
}
- var results []*stack.PacketBuffer
- L:
- for {
- select {
- case pi := <-c.linkEP.Ch:
- results = append(results, pi)
- default:
- break L
- }
- }
-
- if got, want := len(results), ft.expectedFrags; got != want {
- t.Errorf("len(result) got %d, want %d", got, want)
+ if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want {
+ t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
}
- if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet len(result) got %d, want %d", got, want)
+ if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
}
- compareFragments(t, results, source, ft.mtu)
+ compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu)
})
}
}
@@ -335,21 +257,21 @@ func TestFragmentationErrors(t *testing.T) {
fragTests := []struct {
description string
mtu uint32
- hdrLength int
+ transportHeaderLength int
payloadViewsSizes []int
packetCollectorErrors []*tcpip.Error
}{
{"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
{"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
{"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
- {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
- c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
@@ -364,7 +286,7 @@ func TestFragmentationErrors(t *testing.T) {
if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
t.Errorf("err got %v, want %v", got, want)
}
- if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
+ if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
}
})
@@ -372,115 +294,308 @@ func TestFragmentationErrors(t *testing.T) {
}
func TestInvalidFragments(t *testing.T) {
+ const (
+ nicID = 1
+ linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ addr1 = "\x0a\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x02"
+ tos = 0
+ ident = 1
+ ttl = 48
+ protocol = 6
+ )
+
+ payloadGen := func(payloadLen int) []byte {
+ payload := make([]byte, payloadLen)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = 0x30
+ }
+ return payload
+ }
+
+ type fragmentData struct {
+ ipv4fields header.IPv4Fields
+ payload []byte
+ autoChecksum bool // if true, the Checksum field will be overwritten.
+ }
+
// These packets have both IHL and TotalLength set to 0.
- testCases := []struct {
+ tests := []struct {
name string
- packets [][]byte
+ fragments []fragmentData
wantMalformedIPPackets uint64
wantMalformedFragments uint64
}{
{
- "ihl_totallen_zero_valid_frag_offset",
- [][]byte{
- {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- },
- 1,
- 0,
- },
- {
- "ihl_totallen_zero_invalid_frag_offset",
- [][]byte{
- {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL and TotalLength zero, FragmentOffset non-zero",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: 0,
+ TOS: tos,
+ TotalLength: 0,
+ ID: ident,
+ Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments,
+ FragmentOffset: 59776,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(12),
+ autoChecksum: true,
+ },
},
- 1,
- 0,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 0,
},
{
- // Total Length of 37(20 bytes IP header + 17 bytes of
- // payload)
- // Frag Offset of 0x1ffe = 8190*8 = 65520
- // Leading to the fragment end to be past 65535.
- "ihl_totallen_valid_invalid_frag_offset_1",
- [][]byte{
- {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL and TotalLength zero, FragmentOffset zero",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: 0,
+ TOS: tos,
+ TotalLength: 0,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(12),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 0,
},
- // The following 3 tests were found by running a fuzzer and were
- // triggering a panic in the IPv4 reassembler code.
{
- "ihl_less_than_ipv4_minimum_size_1",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ // Payload 17 octets and Fragment offset 65520
+ // Leading to the fragment end to be past 65536.
+ name: "fragment ends past 65536",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 17,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 65520,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(17),
+ autoChecksum: true,
+ },
},
- 2,
- 0,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
{
- "ihl_less_than_ipv4_minimum_size_2",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ // Payload 16 octets and fragment offset 65520
+ // Leading to the fragment end to be exactly 65536.
+ name: "fragment ends exactly at 65536",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 65520,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(16),
+ autoChecksum: true,
+ },
},
- 2,
- 0,
+ wantMalformedIPPackets: 0,
+ wantMalformedFragments: 0,
},
{
- "ihl_less_than_ipv4_minimum_size_3",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL less than IPv4 minimum size",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 12,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 28,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 1944,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 12,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize - 12,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
},
- 2,
- 0,
+ wantMalformedIPPackets: 2,
+ wantMalformedFragments: 0,
},
{
- "fragment_with_short_total_len_extra_payload",
- [][]byte{
- {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "fragment with short TotalLength and extra payload",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize + 4,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 28,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 28816,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize + 4,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 4,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
{
- "multiple_fragments_with_more_fragments_set_to_false",
- [][]byte{
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ name: "multiple fragments with More Fragments flag set to false",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 128,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 8,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
}
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- const nicID tcpip.NICID = 42
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{
ipv4.NewProtocol(),
},
})
+ e := channel.New(0, 1500, linkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ }
+
+ for _, f := range test.fragments {
+ pktSize := header.IPv4MinimumSize + len(f.payload)
+ hdr := buffer.NewPrependable(pktSize)
- var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30})
- var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31})
- ep := channel.New(10, 1500, linkAddr)
- s.CreateNIC(nicID, sniffer.New(ep))
+ ip := header.IPv4(hdr.Prepend(pktSize))
+ ip.Encode(&f.ipv4fields)
+ copy(ip[header.IPv4MinimumSize:], f.payload)
- for _, pkt := range tc.packets {
- ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
+ if f.autoChecksum {
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ }
+
+ vv := hdr.View().ToVectorisedView()
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
}))
}
- if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want {
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
}
- if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want {
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
}
})
@@ -534,6 +649,9 @@ func TestReceiveFragments(t *testing.T) {
// the fragment block size of 8 (RFC 791 section 3.1 page 14).
ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2)
udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:]
+ // Used to test the max reassembled payload length (65,535 octets).
+ ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2)
+ udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:]
type fragmentData struct {
srcAddr tcpip.Address
@@ -827,6 +945,28 @@ func TestReceiveFragments(t *testing.T) {
},
expectedPayloads: nil,
},
+ {
+ name: "Two fragments reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload4Addr1ToAddr2[:65512],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 65512,
+ payload: ipv4Payload4Addr1ToAddr2[65512:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
+ },
}
for _, test := range tests {
@@ -906,3 +1046,252 @@ func TestReceiveFragments(t *testing.T) {
})
}
}
+
+func TestWriteStats(t *testing.T) {
+ const nPackets = 3
+ tests := []struct {
+ name string
+ setup func(*testing.T, *stack.Stack)
+ linkEP func() stack.LinkEndpoint
+ expectSent int
+ expectDropped int
+ expectWritten int
+ }{
+ {
+ name: "Accept all",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ expectSent: nPackets,
+ expectDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Accept all with error",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} },
+ expectSent: nPackets - 1,
+ expectDropped: 0,
+ expectWritten: nPackets - 1,
+ }, {
+ name: "Drop all",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ expectSent: 0,
+ expectDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule that matches only 1
+ // of the 3 packets.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ expectSent: nPackets - 1,
+ expectDropped: 1,
+ expectWritten: nPackets,
+ },
+ }
+
+ // Parameterize the tests to run with both WritePacket and WritePackets.
+ writers := []struct {
+ name string
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ }{
+ {
+ name: "WritePacket",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ nWritten := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
+ return nWritten, err
+ }
+ nWritten++
+ }
+ return nWritten, nil
+ },
+ }, {
+ name: "WritePackets",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
+ },
+ },
+ }
+
+ for _, writer := range writers {
+ t.Run(writer.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ rt := buildRoute(t, nil, test.linkEP())
+
+ var pkts stack.PacketBufferList
+ for i := 0; i < nPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(header.UDPMinimumSize)
+ pkts.PushBack(pkt)
+ }
+
+ test.setup(t, rt.Stack())
+
+ nWritten, _ := writer.writePackets(&rt, pkts)
+
+ if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
+ t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
+ t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ }
+ if nWritten != test.expectWritten {
+ t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ }
+ })
+ }
+ })
+ }
+}
+
+func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ s.CreateNIC(1, linkEP)
+ const (
+ src = "\x10\x00\x00\x01"
+ dst = "\x10\x00\x00\x02"
+ )
+ s.AddAddress(1, ipv4.ProtocolNumber, src)
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ rt, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ }
+ return rt
+}
+
+// limitedEP is a link endpoint that writes up to a certain number of packets
+// before returning errors.
+type limitedEP struct {
+ limit int
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (*limitedEP) MTU() uint32 {
+ // Give an MTU that won't cause fragmentation for IPv4+UDP.
+ return header.IPv4MinimumSize + header.UDPMinimumSize
+}
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 }
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*limitedEP) MaxHeaderLength() uint16 { return 0 }
+
+// LinkAddress implements LinkEndpoint.LinkAddress.
+func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" }
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ if ep.limit == 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+ ep.limit--
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ if ep.limit == 0 {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+ nWritten := ep.limit
+ if nWritten > pkts.Len() {
+ nWritten = pkts.Len()
+ }
+ ep.limit -= nWritten
+ return nWritten, nil
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+ if ep.limit == 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+ ep.limit--
+ return nil
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (*limitedEP) Attach(_ stack.NetworkDispatcher) {}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (*limitedEP) IsAttached() bool { return false }
+
+// Wait implements LinkEndpoint.Wait.
+func (*limitedEP) Wait() {}
+
+// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
+func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
+
+// AddHeader implements LinkEndpoint.AddHeader.
+func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
+// limitedMatcher is an iptables matcher that matches after a certain number of
+// packets are checked against it.
+type limitedMatcher struct {
+ limit int
+}
+
+// Name implements Matcher.Name.
+func (*limitedMatcher) Name() string {
+ return "limitedMatcher"
+}
+
+// Match implements Matcher.Match.
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+ if lm.limit == 0 {
+ return true, false
+ }
+ lm.limit--
+ return false, false
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index bcc64994e..cd5fe3ea8 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -13,6 +13,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 66d3a953a..7430b8fcd 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,8 +15,6 @@
package ipv6
import (
- "fmt"
-
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -71,6 +69,59 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
+// getLinkAddrOption searches NDP options for a given link address option using
+// the provided getAddr function as a filter. Returns the link address if
+// found; otherwise, returns the zero link address value. Also returns true if
+// the options are valid as per the wire format, false otherwise.
+func getLinkAddrOption(it header.NDPOptionIterator, getAddr func(header.NDPOption) tcpip.LinkAddress) (tcpip.LinkAddress, bool) {
+ var linkAddr tcpip.LinkAddress
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ return "", false
+ }
+ if done {
+ break
+ }
+ if addr := getAddr(opt); len(addr) != 0 {
+ // No RFCs define what to do when an NDP message has multiple Link-Layer
+ // Address options. Since no interface can have multiple link-layer
+ // addresses, we consider such messages invalid.
+ if len(linkAddr) != 0 {
+ return "", false
+ }
+ linkAddr = addr
+ }
+ }
+ return linkAddr, true
+}
+
+// getSourceLinkAddr searches NDP options for the source link address option.
+// Returns the link address if found; otherwise, returns the zero link address
+// value. Also returns true if the options are valid as per the wire format,
+// false otherwise.
+func getSourceLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
+ return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress {
+ if src, ok := opt.(header.NDPSourceLinkLayerAddressOption); ok {
+ return src.EthernetAddress()
+ }
+ return ""
+ })
+}
+
+// getTargetLinkAddr searches NDP options for the target link address option.
+// Returns the link address if found; otherwise, returns the zero link address
+// value. Also returns true if the options are valid as per the wire format,
+// false otherwise.
+func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
+ return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress {
+ if dst, ok := opt.(header.NDPTargetLinkLayerAddressOption); ok {
+ return dst.EthernetAddress()
+ }
+ return ""
+ })
+}
+
func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
@@ -137,7 +188,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
+ if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
@@ -147,14 +198,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// NDP messages cannot be fragmented. Also note that in the common case NDP
// datagrams are very small and ToView() will not incur allocations.
ns := header.NDPNeighborSolicit(payload.ToView())
- it, err := ns.Options().Iter(true)
- if err != nil {
- // If we have a malformed NDP NS option, drop the packet.
+ targetAddr := ns.TargetAddress()
+
+ // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast
+ // address.
+ if header.IsV6MulticastAddress(targetAddr) {
received.Invalid.Increment()
return
}
- targetAddr := ns.TargetAddress()
s := r.Stack()
if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
// We will only get an error if the NIC is unrecognized, which should not
@@ -187,39 +239,22 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// so the packet is processed as defined in RFC 4861, as per RFC 4862
// section 5.4.3.
- // Is the NS targetting us?
- if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
+ // Is the NS targeting us?
+ if s.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
return
}
- // If the NS message contains the Source Link-Layer Address option, update
- // the link address cache with the value of the option.
- //
- // TODO(b/148429853): Properly process the NS message and do Neighbor
- // Unreachability Detection.
- var sourceLinkAddr tcpip.LinkAddress
- for {
- opt, done, err := it.Next()
- if err != nil {
- // This should never happen as Iter(true) above did not return an error.
- panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
- }
- if done {
- break
- }
+ it, err := ns.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
- switch opt := opt.(type) {
- case header.NDPSourceLinkLayerAddressOption:
- // No RFCs define what to do when an NS message has multiple Source
- // Link-Layer Address options. Since no interface can have multiple
- // link-layer addresses, we consider such messages invalid.
- if len(sourceLinkAddr) != 0 {
- received.Invalid.Increment()
- return
- }
-
- sourceLinkAddr = opt.EthernetAddress()
- }
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
}
unspecifiedSource := r.RemoteAddress == header.IPv6Any
@@ -237,6 +272,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
} else if unspecifiedSource {
received.Invalid.Increment()
return
+ } else if e.nud != nil {
+ e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr)
}
@@ -304,7 +341,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
+ if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertSize {
received.Invalid.Increment()
return
}
@@ -314,17 +351,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// 5, NDP messages cannot be fragmented. Also note that in the common case
// NDP datagrams are very small and ToView() will not incur allocations.
na := header.NDPNeighborAdvert(payload.ToView())
- it, err := na.Options().Iter(true)
- if err != nil {
- // If we have a malformed NDP NA option, drop the packet.
- received.Invalid.Increment()
- return
- }
-
targetAddr := na.TargetAddress()
- stack := r.Stack()
+ s := r.Stack()
- if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil {
+ if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
// We will only get an error if the NIC is unrecognized, which should not
// happen. For now short-circuit this packet.
//
@@ -335,7 +365,14 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// DAD on, implying the address is not unique. In this case we let the
// stack know so it can handle such a scenario and do nothing furthur with
// the NDP NA.
- stack.DupTentativeAddrDetected(e.nicID, targetAddr)
+ s.DupTentativeAddrDetected(e.nicID, targetAddr)
+ return
+ }
+
+ it, err := na.Options().Iter(false /* check */)
+ if err != nil {
+ // If we have a malformed NDP NA option, drop the packet.
+ received.Invalid.Increment()
return
}
@@ -348,39 +385,25 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// TODO(b/143147598): Handle the scenario described above. Also inform the
// netstack integration that a duplicate address was detected outside of
// DAD.
+ targetLinkAddr, ok := getTargetLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
// If the NA message has the target link layer option, update the link
// address cache with the link address for the target of the message.
- //
- // TODO(b/148429853): Properly process the NA message and do Neighbor
- // Unreachability Detection.
- var targetLinkAddr tcpip.LinkAddress
- for {
- opt, done, err := it.Next()
- if err != nil {
- // This should never happen as Iter(true) above did not return an error.
- panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
- }
- if done {
- break
+ if len(targetLinkAddr) != 0 {
+ if e.nud == nil {
+ e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
+ return
}
- switch opt := opt.(type) {
- case header.NDPTargetLinkLayerAddressOption:
- // No RFCs define what to do when an NA message has multiple Target
- // Link-Layer Address options. Since no interface can have multiple
- // link-layer addresses, we consider such messages invalid.
- if len(targetLinkAddr) != 0 {
- received.Invalid.Increment()
- return
- }
-
- targetLinkAddr = opt.EthernetAddress()
- }
- }
-
- if len(targetLinkAddr) != 0 {
- e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
+ e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{
+ Solicited: na.SolicitedFlag(),
+ Override: na.OverrideFlag(),
+ IsRouter: na.RouterFlag(),
+ })
}
case header.ICMPv6EchoRequest:
@@ -440,27 +463,75 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
case header.ICMPv6RouterSolicit:
received.RouterSolicit.Increment()
- if !isNDPValid() {
+
+ //
+ // Validate the RS as per RFC 4861 section 6.1.1.
+ //
+
+ // Is the NDP payload of sufficient size to hold a Router Solictation?
+ if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize {
received.Invalid.Increment()
return
}
- case header.ICMPv6RouterAdvert:
- received.RouterAdvert.Increment()
+ stack := r.Stack()
- // Is the NDP payload of sufficient size to hold a Router
- // Advertisement?
- if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() {
+ // Is the networking stack operating as a router?
+ if !stack.Forwarding(ProtocolNumber) {
+ // ... No, silently drop the packet.
+ received.RouterOnlyPacketsDroppedByHost.Increment()
+ return
+ }
+
+ // Note that in the common case NDP datagrams are very small and ToView()
+ // will not incur allocations.
+ rs := header.NDPRouterSolicit(payload.ToView())
+ it, err := rs.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the packet.
received.Invalid.Increment()
return
}
- routerAddr := iph.SourceAddress()
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+
+ // If the RS message has the source link layer option, update the link
+ // address cache with the link address for the source of the message.
+ if len(sourceLinkAddr) != 0 {
+ // As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST
+ // NOT be included when the source IP address is the unspecified address.
+ // Otherwise, it SHOULD be included on link layers that have addresses.
+ if r.RemoteAddress == header.IPv6Any {
+ received.Invalid.Increment()
+ return
+ }
+
+ if e.nud != nil {
+ // A RS with a specified source IP address modifies the NUD state
+ // machine in the same way a reachability probe would.
+ e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ }
+ }
+
+ case header.ICMPv6RouterAdvert:
+ received.RouterAdvert.Increment()
//
// Validate the RA as per RFC 4861 section 6.1.2.
//
+ // Is the NDP payload of sufficient size to hold a Router Advertisement?
+ if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ routerAddr := iph.SourceAddress()
+
// Is the IP Source Address a link-local address?
if !header.IsV6LinkLocalAddress(routerAddr) {
// ...No, silently drop the packet.
@@ -468,16 +539,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- // The remainder of payload must be only the router advertisement, so
- // payload.ToView() always returns the advertisement. Per RFC 6980 section
- // 5, NDP messages cannot be fragmented. Also note that in the common case
- // NDP datagrams are very small and ToView() will not incur allocations.
+ // Note that in the common case NDP datagrams are very small and ToView()
+ // will not incur allocations.
ra := header.NDPRouterAdvert(payload.ToView())
- opts := ra.Options()
+ it, err := ra.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
- // Are options valid as per the wire format?
- if _, err := opts.Iter(true); err != nil {
- // ...No, silently drop the packet.
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
received.Invalid.Increment()
return
}
@@ -487,12 +560,33 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// as RFC 4861 section 6.1.2 is concerned.
//
+ // If the RA has the source link layer option, update the link address
+ // cache with the link address for the advertised router.
+ if len(sourceLinkAddr) != 0 && e.nud != nil {
+ e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ }
+
// Tell the NIC to handle the RA.
stack := r.Stack()
- rxNICID := r.NICID()
- stack.HandleNDPRA(rxNICID, routerAddr, ra)
+ stack.HandleNDPRA(e.nicID, routerAddr, ra)
case header.ICMPv6RedirectMsg:
+ // TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating
+ // this redirect message, as per RFC 4871 section 7.3.3:
+ //
+ // "A Neighbor Cache entry enters the STALE state when created as a
+ // result of receiving packets other than solicited Neighbor
+ // Advertisements (i.e., Router Solicitations, Router Advertisements,
+ // Redirects, and Neighbor Solicitations). These packets contain the
+ // link-layer address of either the sender or, in the case of Redirect,
+ // the redirection target. However, receipt of these link-layer
+ // addresses does not confirm reachability of the forward-direction path
+ // to that node. Placing a newly created Neighbor Cache entry for which
+ // the link-layer address is known in the STALE state provides assurance
+ // that path failures are detected quickly. In addition, should a cached
+ // link-layer address be modified due to receiving one of the above
+ // messages, the state SHOULD also be set to STALE to provide prompt
+ // verification that the path to the new link-layer address is working."
received.RedirectMsg.Increment()
if !isNDPValid() {
received.Invalid.Increment()
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 9e4eeea77..0f50bfb8e 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -31,6 +31,8 @@ import (
)
const (
+ nicID = 1
+
linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
@@ -49,7 +51,10 @@ type stubLinkEndpoint struct {
}
func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return 0
+ // Indicate that resolution for link layer addresses is required to send
+ // packets over this link. This is needed so the NIC knows to allocate a
+ // neighbor table.
+ return stack.CapabilityResolutionRequired
}
func (*stubLinkEndpoint) MaxHeaderLength() uint16 {
@@ -84,16 +89,184 @@ func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtoco
func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {
}
+type stubNUDHandler struct{}
+
+var _ stack.NUDHandler = (*stubNUDHandler)(nil)
+
+func (*stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) {
+}
+
+func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) {
+}
+
+func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
+}
+
func TestICMPCounts(t *testing.T) {
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ UseNeighborCache: test.useNeighborCache,
+ })
+ {
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+ ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ defer ep.Close()
+
+ r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ }
+ defer r.Release()
+
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
+ types := []struct {
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ }{
+ {
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ },
+ {
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ },
+ {
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ },
+ {
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ },
+ }
+
+ handleIPv6Payload := func(icmp header.ICMPv6) {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize,
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ep.HandlePacket(&r, pkt)
+ }
+
+ for _, typ := range types {
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ handleIPv6Payload(icmp)
+ }
+
+ // Construct an empty ICMP packet so that
+ // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
+ handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
+
+ icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
+ if got, want := s.Value(), uint64(1); got != want {
+ t.Errorf("got %s = %d, want = %d", name, got, want)
+ }
+ })
+ if t.Failed() {
+ t.Logf("stats:\n%+v", s.Stats())
+ }
+ })
+ }
+}
+
+func TestICMPCountsWithNeighborCache(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ UseNeighborCache: true,
})
{
- if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
}
}
@@ -105,7 +278,7 @@ func TestICMPCounts(t *testing.T) {
s.SetRouteTable(
[]tcpip.Route{{
Destination: subnet,
- NIC: 1,
+ NIC: nicID,
}},
)
}
@@ -114,12 +287,12 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(0, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
defer ep.Close()
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
}
defer r.Release()
@@ -265,19 +438,19 @@ func newTestContext(t *testing.T) *testContext {
if testing.Verbose() {
wrappedEP0 = sniffer.New(wrappedEP0)
}
- if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
+ if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
- if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress lladdr0: %v", err)
}
c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
- if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
+ if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
+ if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil {
t.Fatalf("AddAddress lladdr1: %v", err)
}
@@ -288,7 +461,7 @@ func newTestContext(t *testing.T) *testContext {
c.s0.SetRouteTable(
[]tcpip.Route{{
Destination: subnet0,
- NIC: 1,
+ NIC: nicID,
}},
)
subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
@@ -298,7 +471,7 @@ func newTestContext(t *testing.T) *testContext {
c.s1.SetRouteTable(
[]tcpip.Route{{
Destination: subnet1,
- NIC: 1,
+ NIC: nicID,
}},
)
@@ -359,9 +532,9 @@ func TestLinkResolution(t *testing.T) {
c := newTestContext(t)
defer c.cleanup()
- r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
}
defer r.Release()
@@ -376,14 +549,14 @@ func TestLinkResolution(t *testing.T) {
var wq waiter.Queue
ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err)
}
for {
- _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}})
+ _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}})
if resCh != nil {
if err != tcpip.ErrNoLinkAddress {
- t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err)
+ t.Fatalf("ep.Write(_) = (_, <non-nil>, %s), want = (_, <non-nil>, tcpip.ErrNoLinkAddress)", err)
}
for _, args := range []routeArgs{
{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
@@ -399,7 +572,7 @@ func TestLinkResolution(t *testing.T) {
continue
}
if err != nil {
- t.Fatalf("ep.Write(_) = _, _, %s", err)
+ t.Fatalf("ep.Write(_) = (_, _, %s)", err)
}
break
}
@@ -424,6 +597,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
size int
extraData []byte
statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ routerOnly bool
}{
{
name: "DstUnreachable",
@@ -480,6 +654,8 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterSolicit
},
+ // Hosts MUST silently discard any received Router Solicitation messages.
+ routerOnly: true,
},
{
name: "RouterAdvert",
@@ -516,84 +692,133 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
},
}
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr0)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }},
- )
- }
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
- handleIPv6Payload := func(checksum bool) {
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- if checksum {
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, typ := range types {
+ for _, isRouter := range []bool{false, true} {
+ name := typ.name
+ if isRouter {
+ name += " (Router)"
+ }
+ t.Run(name, func(t *testing.T) {
+ e := channel.New(0, 1280, linkAddr0)
+
+ // Indicate that resolution for link layer addresses is required to
+ // send packets over this link. This is needed so the NIC knows to
+ // allocate a neighbor table.
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ UseNeighborCache: test.useNeighborCache,
+ })
+ if isRouter {
+ // Enabling forwarding makes the stack act as a router.
+ s.SetForwarding(ProtocolNumber, true)
+ }
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
+ }
+
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(checksum bool) {
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ if checksum {
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
+ }
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
+ })
+ e.InjectInbound(ProtocolNumber, pkt)
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ routerOnly := stats.RouterOnlyPacketsDroppedByHost
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := routerOnly.Value(); got != 0 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Router only count should not have increased.
+ if got := routerOnly.Value(); got != 0 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ if !isRouter && typ.routerOnly && test.useNeighborCache {
+ // Router only count should have increased.
+ if got := routerOnly.Value(); got != 1 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got)
+ }
+ }
+ })
}
- ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
- })
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
}
})
}
@@ -696,11 +921,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
}
{
@@ -711,7 +936,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
s.SetRouteTable(
[]tcpip.Route{{
Destination: subnet,
- NIC: 1,
+ NIC: nicID,
}},
)
}
@@ -750,7 +975,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Without setting checksum, the incoming packet should
@@ -761,13 +986,13 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
}
// Rx count of type typ.typ should not have increased.
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// When checksum is set, it should be received.
handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Invalid count should not have increased again.
if got := invalid.Value(); got != 1 {
@@ -874,12 +1099,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -889,7 +1114,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
s.SetRouteTable(
[]tcpip.Route{{
Destination: subnet,
- NIC: 1,
+ NIC: nicID,
}},
)
}
@@ -929,7 +1154,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Without setting checksum, the incoming packet should
@@ -940,13 +1165,13 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
}
// Rx count of type typ.typ should not have increased.
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// When checksum is set, it should be received.
handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Invalid count should not have increased again.
if got := invalid.Value(); got != 1 {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 0eafe9790..ee64d92d8 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -48,6 +49,7 @@ type endpoint struct {
nicID tcpip.NICID
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
+ nud stack.NUDHandler
dispatcher stack.TransportDispatcher
protocol *protocol
stack *stack.Stack
@@ -106,6 +108,32 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.stack.FindNICNameFromID(e.NICID())
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ // iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesOutputDropped.Increment()
+ return nil
+ }
+
+ // If the packet is manipulated as per NAT Output rules, handle packet
+ // based on destination address and do not send the packet to link
+ // layer.
+ //
+ // TODO(gvisor.dev/issue/170): We should do this for every
+ // packet, rather than only NATted packets, but removing this check
+ // short circuits broadcasts before they are sent out to other hosts.
+ if pkt.NatDone {
+ netHeader := header.IPv6(pkt.NetworkHeader().View())
+ if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ ep.HandlePacket(&route, pkt)
+ return nil
+ }
+ }
+
if r.Loop&stack.PacketLoop != 0 {
loopedR := r.MakeLoopedRoute()
@@ -120,8 +148,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ return err
+ }
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt)
+ return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -137,9 +168,50 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
e.addIPHeader(r, pb, params)
}
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.stack.FindNICNameFromID(e.NICID())
+ ipt := e.stack.IPTables()
+ dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ if len(dropped) == 0 && len(natPkts) == 0 {
+ // Fast path: If no packets are to be dropped then we can just invoke the
+ // faster WritePackets API directly.
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+ r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+
+ // Slow path as we are dropping some packets in the batch degrade to
+ // emitting one packet at a time.
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if _, ok := dropped[pkt]; ok {
+ continue
+ }
+ if _, ok := natPkts[pkt]; ok {
+ netHeader := header.IPv6(pkt.NetworkHeader().View())
+ if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ src := netHeader.SourceAddress()
+ dst := netHeader.DestinationAddress()
+ route := r.ReverseRoute(src, dst)
+ ep.HandlePacket(&route, pkt)
+ n++
+ continue
+ }
+ }
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ // Dropped packets aren't errors, so include them in
+ // the return value.
+ return n + len(dropped), err
+ }
+ n++
+ }
+
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ // Dropped packets aren't errors, so include them in the return value.
+ return n + len(dropped), nil
}
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
@@ -168,6 +240,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
hasFragmentHeader := false
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and will not be forwarded.
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ // iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesInputDropped.Increment()
+ return
+ }
+
for firstHeader := true; ; firstHeader = false {
extHdr, done, err := it.Next()
if err != nil {
@@ -310,21 +391,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// The packet is a fragment, let's try to reassemble it.
start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
- last := start + uint16(fragmentPayloadLen) - 1
- // Drop the packet if the fragmentOffset is incorrect. i.e the
- // combination of fragmentOffset and pkt.Data.size() causes a
- // wrap around resulting in last being less than the offset.
- if last < start {
+ // Drop the fragment if the size of the reassembled payload would exceed
+ // the maximum payload size.
+ if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
- var ready bool
// Note that pkt doesn't have its transport header set after reassembly,
// and won't until DeliverNetworkPacket sets it.
- pkt.Data, ready, err = e.protocol.fragmentation.Process(
+ data, proto, ready, err := e.protocol.fragmentation.Process(
// IPv6 ignores the Protocol field since the ID only needs to be unique
// across source-destination pairs, as per RFC 8200 section 4.5.
fragmentation.FragmentID{
@@ -333,8 +411,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ID: extHdr.ID(),
},
start,
- last,
+ start+uint16(fragmentPayloadLen)-1,
extHdr.More(),
+ uint8(rawPayload.Identifier),
rawPayload.Buf,
)
if err != nil {
@@ -342,12 +421,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
+ pkt.Data = data
if ready {
// We create a new iterator with the reassembled packet because we could
// have more extension headers in the reassembled payload, as per RFC
- // 8200 section 4.5.
- it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data)
+ // 8200 section 4.5. We also use the NextHeader value from the first
+ // fragment.
+ it = header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(proto), pkt.Data)
}
case header.IPv6DestinationOptionsExtHdr:
@@ -453,11 +534,12 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
return &endpoint{
nicID: nicID,
linkEP: linkEP,
linkAddrCache: linkAddrCache,
+ nud: nud,
dispatcher: dispatcher,
protocol: p,
stack: st,
@@ -465,10 +547,10 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddres
}
// SetOption implements NetworkProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case tcpip.DefaultTTLOption:
- p.SetDefaultTTL(uint8(v))
+ case *tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(*v))
return nil
default:
return tcpip.ErrUnknownProtocolOption
@@ -476,7 +558,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
// Option implements NetworkProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(p.DefaultTTL())
@@ -502,75 +584,14 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// Parse implements stack.TransportProtocol.Parse.
+// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
- hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt)
if !ok {
return 0, false, false
}
- ipHdr := header.IPv6(hdr)
-
- // dataClone consists of:
- // - Any IPv6 header bytes after the first 40 (i.e. extensions).
- // - The transport header, if present.
- // - Any other payload data.
- views := [8]buffer.View{}
- dataClone := pkt.Data.Clone(views[:])
- dataClone.TrimFront(header.IPv6MinimumSize)
- it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
-
- // Iterate over the IPv6 extensions to find their length.
- //
- // Parsing occurs again in HandlePacket because we don't track the
- // extensions in PacketBuffer. Unfortunately, that means HandlePacket
- // has to do the parsing work again.
- var nextHdr tcpip.TransportProtocolNumber
- foundNext := true
- extensionsSize := 0
-traverseExtensions:
- for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() {
- if err != nil {
- break
- }
- // If we exhaust the extension list, the entire packet is the IPv6 header
- // and (possibly) extensions.
- if done {
- extensionsSize = dataClone.Size()
- foundNext = false
- break
- }
-
- switch extHdr := extHdr.(type) {
- case header.IPv6FragmentExtHdr:
- // If this is an atomic fragment, we don't have to treat it specially.
- if !extHdr.More() && extHdr.FragmentOffset() == 0 {
- continue
- }
- // This is a non-atomic fragment and has to be re-assembled before we can
- // examine the payload for a transport header.
- foundNext = false
-
- case header.IPv6RawPayloadHeader:
- // We've found the payload after any extensions.
- extensionsSize = dataClone.Size() - extHdr.Buf.Size()
- nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
- break traverseExtensions
-
- default:
- // Any other extension is a no-op, keep looping until we find the payload.
- }
- }
-
- // Put the IPv6 header with extensions in pkt.NetworkHeader().
- hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize)
- if !ok {
- panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
- }
- ipHdr = header.IPv6(hdr)
- pkt.Data.CapLength(int(ipHdr.PayloadLength()))
- pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
- return nextHdr, foundNext, true
+ return proto, !fragMore && fragOffset == 0, true
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 0a183bfde..9eea1de8d 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "math"
"testing"
"github.com/google/go-cmp/cmp"
@@ -687,6 +688,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Used to test cases where the fragment blocks are not a multiple of
// the fragment block size of 8 (RFC 8200 section 4.5).
udpPayload3Length = 127
+ udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
fragmentExtHdrLen = 8
// Note, not all routing extension headers will be 8 bytes but this test
// uses 8 byte routing extension headers for most sub tests.
@@ -731,6 +733,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:]
ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2)
+ var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte
+ udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:]
+ ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2)
+
tests := []struct {
name string
expectedPayload []byte
@@ -866,6 +872,46 @@ func TestReceiveIPv6Fragments(t *testing.T) {
expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
},
{
+ name: "Two fragments with different Next Header values",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ // NextHeader value is different than the one in the first fragment, so
+ // this NextHeader should be ignored.
+ buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
name: "Two fragments with last fragment size not a multiple of fragment block size",
fragments: []fragmentData{
{
@@ -980,6 +1026,44 @@ func TestReceiveIPv6Fragments(t *testing.T) {
expectedPayloads: nil,
},
{
+ name: "Two fragments reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+65520,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[:65520],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8190, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[65520:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
+ },
+ {
name: "Two fragments with per-fragment routing header with zero segments left",
fragments: []fragmentData{
{
@@ -1532,3 +1616,343 @@ func TestReceiveIPv6Fragments(t *testing.T) {
})
}
}
+
+func TestInvalidIPv6Fragments(t *testing.T) {
+ const (
+ nicID = 1
+ fragmentExtHdrLen = 8
+ )
+
+ payloadGen := func(payloadLen int) []byte {
+ payload := make([]byte, payloadLen)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = 0x30
+ }
+ return payload
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ wantMalformedIPPackets uint64
+ wantMalformedFragments uint64
+ }{
+ {
+ name: "fragments reassembled into a payload exceeding the max IPv6 payload size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16,
+ []buffer.View{
+ // Fragment extension header.
+ // Fragment offset = 8190, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8,
+ ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8,
+ 0, 0, 0, 1}),
+ // Payload length = 16
+ payloadGen(16),
+ },
+ ),
+ },
+ },
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ NewProtocol(),
+ },
+ })
+ e := channel.New(0, 1500, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+
+ // Serialize IPv6 fixed header.
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(f.data.Size()),
+ NextHeader: f.nextHdr,
+ HopLimit: 255,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
+ })
+
+ vv := hdr.View().ToVectorisedView()
+ vv.Append(f.data)
+
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
+ t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want)
+ }
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
+ t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
+func TestWriteStats(t *testing.T) {
+ const nPackets = 3
+ tests := []struct {
+ name string
+ setup func(*testing.T, *stack.Stack)
+ linkEP func() stack.LinkEndpoint
+ expectSent int
+ expectDropped int
+ expectWritten int
+ }{
+ {
+ name: "Accept all",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ expectSent: nPackets,
+ expectDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Accept all with error",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} },
+ expectSent: nPackets - 1,
+ expectDropped: 0,
+ expectWritten: nPackets - 1,
+ }, {
+ name: "Drop all",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ expectSent: 0,
+ expectDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule that matches only 1
+ // of the 3 packets.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ expectSent: nPackets - 1,
+ expectDropped: 1,
+ expectWritten: nPackets,
+ },
+ }
+
+ writers := []struct {
+ name string
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ }{
+ {
+ name: "WritePacket",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ nWritten := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
+ return nWritten, err
+ }
+ nWritten++
+ }
+ return nWritten, nil
+ },
+ }, {
+ name: "WritePackets",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
+ },
+ },
+ }
+
+ for _, writer := range writers {
+ t.Run(writer.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ rt := buildRoute(t, nil, test.linkEP())
+
+ var pkts stack.PacketBufferList
+ for i := 0; i < nPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(header.UDPMinimumSize)
+ pkts.PushBack(pkt)
+ }
+
+ test.setup(t, rt.Stack())
+
+ nWritten, _ := writer.writePackets(&rt, pkts)
+
+ if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
+ t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
+ t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ }
+ if nWritten != test.expectWritten {
+ t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ }
+ })
+ }
+ })
+ }
+}
+
+func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ s.CreateNIC(1, linkEP)
+ const (
+ src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ )
+ s.AddAddress(1, ProtocolNumber, src)
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ rt, err := s.FindRoute(0, src, dst, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ }
+ return rt
+}
+
+// limitedEP is a link endpoint that writes up to a certain number of packets
+// before returning errors.
+type limitedEP struct {
+ limit int
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (*limitedEP) MTU() uint32 {
+ return header.IPv6MinimumMTU
+}
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 }
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*limitedEP) MaxHeaderLength() uint16 { return 0 }
+
+// LinkAddress implements LinkEndpoint.LinkAddress.
+func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" }
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ if ep.limit == 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+ ep.limit--
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ if ep.limit == 0 {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+ nWritten := ep.limit
+ if nWritten > pkts.Len() {
+ nWritten = pkts.Len()
+ }
+ ep.limit -= nWritten
+ return nWritten, nil
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+ if ep.limit == 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+ ep.limit--
+ return nil
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (*limitedEP) Attach(_ stack.NetworkDispatcher) {}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (*limitedEP) IsAttached() bool { return false }
+
+// Wait implements LinkEndpoint.Wait.
+func (*limitedEP) Wait() {}
+
+// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
+func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
+
+// AddHeader implements LinkEndpoint.AddHeader.
+func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
+// limitedMatcher is an iptables matcher that matches after a certain number of
+// packets are checked against it.
+type limitedMatcher struct {
+ limit int
+}
+
+// Name implements Matcher.Name.
+func (*limitedMatcher) Name() string {
+ return "limitedMatcher"
+}
+
+// Match implements Matcher.Match.
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+ if lm.limit == 0 {
+ return true, false
+ }
+ lm.limit--
+ return false, false
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index af71a7d6b..7434df4a1 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -18,6 +18,7 @@ import (
"strings"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -30,12 +31,13 @@ import (
// setupStackAndEndpoint creates a stack with a single NIC with a link-local
// address llladdr and an IPv6 endpoint to a remote with link-local address
// rlladdr
-func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
+func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeighborCache bool) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ UseNeighborCache: useNeighborCache,
})
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
@@ -63,8 +65,7 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
-
+ ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
return s, ep
}
@@ -171,6 +172,123 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
}
}
+// TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests
+// that receiving a valid NDP NS message with the Source Link Layer Address
+// option results in a new entry in the link address cache for the sender of
+// the message.
+func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ neighbors, err := s.Neighbors(nicID)
+ if err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ }
+
+ neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
+ for _, n := range neighbors {
+ if existing, ok := neighborByAddr[n.Addr]; ok {
+ if diff := cmp.Diff(existing, n); diff != "" {
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
+ }
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing)
+ }
+ neighborByAddr[n.Addr] = n
+ }
+
+ if neigh, ok := neighborByAddr[lladdr1]; len(test.expectedLinkAddr) != 0 {
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+
+ if !ok {
+ t.Fatalf("expected a neighbor entry for %q", lladdr1)
+ }
+ if neigh.LinkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", neigh.LinkAddr, test.expectedLinkAddr)
+ }
+ if neigh.State != stack.Stale {
+ t.Errorf("got NUD state = %s, want = %s", neigh.State, stack.Stale)
+ }
+ } else {
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+
+ if ok {
+ t.Fatalf("unexpectedly got neighbor entry: %s", neigh)
+ }
+ }
+ })
+ }
+}
+
func TestNeighorSolicitationResponse(t *testing.T) {
const nicID = 1
nicAddr := lladdr0
@@ -180,6 +298,20 @@ func TestNeighorSolicitationResponse(t *testing.T) {
remoteLinkAddr0 := linkAddr1
remoteLinkAddr1 := linkAddr2
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
+
tests := []struct {
name string
nsOpts header.NDPOptionsSerializer
@@ -338,86 +470,92 @@ func TestNeighorSolicitationResponse(t *testing.T) {
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- e := channel.New(1, 1280, nicLinkAddr)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
- }
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ UseNeighborCache: stackTyp.useNeighborCache,
+ })
+ e := channel.New(1, 1280, nicLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ }
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
- ns.SetTargetAddress(nicAddr)
- opts := ns.Options()
- opts.Serialize(test.nsOpts)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: test.nsSrc,
- DstAddr: test.nsDst,
- })
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(nicAddr)
+ opts := ns.Options()
+ opts.Serialize(test.nsOpts)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
+ e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
- e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
+ if test.nsInvalid {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
- if test.nsInvalid {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
+ if p, got := e.Read(); got {
+ t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
+ }
- if p, got := e.Read(); got {
- t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
- }
+ // If we expected the NS to be invalid, we have nothing else to check.
+ return
+ }
- // If we expected the NS to be invalid, we have nothing else to check.
- return
- }
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
+ p, got := e.Read()
+ if !got {
+ t.Fatal("expected an NDP NA response")
+ }
- p, got := e.Read()
- if !got {
- t.Fatal("expected an NDP NA response")
- }
+ if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ }
- if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.naSrc),
+ checker.DstAddr(test.naDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNA(
+ checker.NDPNASolicitedFlag(test.naSolicited),
+ checker.NDPNATargetAddress(nicAddr),
+ checker.NDPNAOptions([]header.NDPOption{
+ header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
+ }),
+ ))
+ })
}
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.naSrc),
- checker.DstAddr(test.naDst),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNA(
- checker.NDPNASolicitedFlag(test.naSolicited),
- checker.NDPNATargetAddress(nicAddr),
- checker.NDPNAOptions([]header.NDPOption{
- header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
- }),
- ))
})
}
}
@@ -532,197 +670,380 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
}
}
-func TestNDPValidation(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
- t.Helper()
-
- // Create a stack with the assigned link-local address lladdr0
- // and an endpoint to lladdr1.
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
-
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
- }
-
- return s, ep, r
- }
-
- handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
- nextHdr := uint8(header.ICMPv6ProtocolNumber)
- var extensions buffer.View
- if atomicFragment {
- extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
- extensions[0] = nextHdr
- nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions),
- Data: payload.ToVectorisedView(),
- })
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions)))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload) + len(extensions)),
- NextHeader: nextHdr,
- HopLimit: hopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- })
- if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
- t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
- }
- ep.HandlePacket(r, pkt)
- }
-
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
+// TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests
+// that receiving a valid NDP NA message with the Target Link Layer Address
+// option does not result in a new entry in the neighbor cache for the target
+// of the message.
+func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) {
+ const nicID = 1
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- extraData []byte
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ tests := []struct {
+ name string
+ optsBuf []byte
+ isValid bool
}{
{
- name: "RouterSolicit",
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- },
- },
- {
- name: "RouterAdvert",
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- },
+ name: "Valid",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
+ isValid: true,
},
{
- name: "NeighborSolicit",
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- },
+ name: "Too Small",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
},
{
- name: "NeighborAdvert",
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- },
+ name: "Invalid Length",
+ optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
},
{
- name: "RedirectMsg",
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
+ name: "Multiple",
+ optsBuf: []byte{
+ 2, 1, 2, 3, 4, 5, 6, 7,
+ 2, 1, 2, 3, 4, 5, 6, 8,
},
},
}
- subTests := []struct {
- name string
- atomicFragment bool
- hopLimit uint8
- code header.ICMPv6Code
- valid bool
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr1)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ neighbors, err := s.Neighbors(nicID)
+ if err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ }
+
+ neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
+ for _, n := range neighbors {
+ if existing, ok := neighborByAddr[n.Addr]; ok {
+ if diff := cmp.Diff(existing, n); diff != "" {
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
+ }
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing)
+ }
+ neighborByAddr[n.Addr] = n
+ }
+
+ if neigh, ok := neighborByAddr[lladdr1]; ok {
+ t.Fatalf("unexpectedly got neighbor entry: %s", neigh)
+ }
+
+ if test.isValid {
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+func TestNDPValidation(t *testing.T) {
+ stacks := []struct {
+ name string
+ useNeighborCache bool
}{
{
- name: "Valid",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: true,
- },
- {
- name: "Fragmented",
- atomicFragment: true,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: false,
- },
- {
- name: "Invalid hop limit",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit - 1,
- code: 0,
- valid: false,
+ name: "linkAddrCache",
+ useNeighborCache: false,
},
{
- name: "Invalid ICMPv6 code",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 1,
- valid: false,
+ name: "neighborCache",
+ useNeighborCache: true,
},
}
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- for _, test := range subTests {
- t.Run(test.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ t.Helper()
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
+ // Create a stack with the assigned link-local address lladdr0
+ // and an endpoint to lladdr1.
+ s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache)
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
- // Rx count of the NDP message should initially be 0.
- if got := typStat.Value(); got != 0 {
- t.Errorf("got %s = %d, want = 0", typ.name, got)
- }
+ return s, ep, r
+ }
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ nextHdr := uint8(header.ICMPv6ProtocolNumber)
+ var extensions buffer.View
+ if atomicFragment {
+ extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
+ extensions[0] = nextHdr
+ nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ }
- if t.Failed() {
- t.FailNow()
- }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions),
+ Data: payload.ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions)))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(payload) + len(extensions)),
+ NextHeader: nextHdr,
+ HopLimit: hopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
+ t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
+ }
+ ep.HandlePacket(r, pkt)
+ }
- handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
- // Rx count of the NDP packet should have increased.
- if got := typStat.Value(); got != 1 {
- t.Errorf("got %s = %d, want = 1", typ.name, got)
- }
+ var sllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
- want := uint64(0)
- if !test.valid {
- // Invalid count should have increased.
- want = 1
- }
- if got := invalid.Value(); got != want {
- t.Errorf("got invalid = %d, want = %d", got, want)
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ routerOnly bool
+ }{
+ {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ routerOnly: true,
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ extraData: sllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ atomicFragment bool
+ hopLimit uint8
+ code header.ICMPv6Code
+ valid bool
+ }{
+ {
+ name: "Valid",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: true,
+ },
+ {
+ name: "Fragmented",
+ atomicFragment: true,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid hop limit",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit - 1,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid ICMPv6 code",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 1,
+ valid: false,
+ },
+ }
+
+ for _, typ := range types {
+ for _, isRouter := range []bool{false, true} {
+ name := typ.name
+ if isRouter {
+ name += " (Router)"
}
- })
+
+ t.Run(name, func(t *testing.T) {
+ for _, test := range subTests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep, r := setup(t)
+ defer r.Release()
+
+ if isRouter {
+ // Enabling forwarding makes the stack act as a router.
+ s.SetForwarding(ProtocolNumber, true)
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ routerOnly := stats.RouterOnlyPacketsDroppedByHost
+ typStat := typ.statCounter(stats)
+
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetCode(test.code)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+
+ // Rx count of the NDP message should initially be 0.
+ if got := typStat.Value(); got != 0 {
+ t.Errorf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+
+ // RouterOnlyPacketsReceivedByHost count should initially be 0.
+ if got := routerOnly.Value(); got != 0 {
+ t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+
+ // Rx count of the NDP packet should have increased.
+ if got := typStat.Value(); got != 1 {
+ t.Errorf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ want := uint64(0)
+ if !test.valid {
+ // Invalid count should have increased.
+ want = 1
+ }
+ if got := invalid.Value(); got != want {
+ t.Errorf("got invalid = %d, want = %d", got, want)
+ }
+
+ want = 0
+ if test.valid && !isRouter && typ.routerOnly {
+ // RouterOnlyPacketsReceivedByHost count should have increased.
+ want = 1
+ }
+ if got := routerOnly.Value(); got != want {
+ t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want)
+ }
+
+ })
+ }
+ })
+ }
}
})
}
+
}
// TestRouterAdvertValidation tests that when the NIC is configured to handle
// NDP Router Advertisement packets, it validates the Router Advertisement
// properly before handling them.
func TestRouterAdvertValidation(t *testing.T) {
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
+
tests := []struct {
name string
src tcpip.Address
@@ -844,61 +1165,67 @@ func TestRouterAdvertValidation(t *testing.T) {
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ UseNeighborCache: stackTyp.useNeighborCache,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
- icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(header.ICMPv6RouterAdvert)
- pkt.SetCode(test.code)
- copy(pkt.NDPPayload(), test.ndpPayload)
- payloadLength := hdr.UsedLength()
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- })
+ icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(test.code)
+ copy(pkt.NDPPayload(), test.ndpPayload)
+ payloadLength := hdr.UsedLength()
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- rxRA := stats.RouterAdvert
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ rxRA := stats.RouterAdvert
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := rxRA.Value(); got != 0 {
- t.Fatalf("got rxRA = %d, want = 0", got)
- }
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := rxRA.Value(); got != 0 {
+ t.Fatalf("got rxRA = %d, want = 0", got)
+ }
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
- if got := rxRA.Value(); got != 1 {
- t.Fatalf("got rxRA = %d, want = 1", got)
- }
+ if got := rxRA.Value(); got != 1 {
+ t.Fatalf("got rxRA = %d, want = 1", got)
+ }
- if test.expectedSuccess {
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- } else {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
+ if test.expectedSuccess {
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
}
})
}
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
new file mode 100644
index 000000000..e218563d0
--- /dev/null
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "testutil",
+ srcs = [
+ "testutil.go",
+ ],
+ visibility = ["//pkg/tcpip/network/ipv4:__pkg__"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
new file mode 100644
index 000000000..bf5ce74be
--- /dev/null
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil defines types and functions used to test Network Layer
+// functionality such as IP fragmentation.
+package testutil
+
+import (
+ "fmt"
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// TestEndpoint is an endpoint used for testing, it stores packets written to it
+// and can mock errors.
+type TestEndpoint struct {
+ *channel.Endpoint
+
+ // WrittenPackets is where we store packets written via WritePacket().
+ WrittenPackets []*stack.PacketBuffer
+
+ packetCollectorErrors []*tcpip.Error
+}
+
+// NewTestEndpoint creates a new TestEndpoint endpoint.
+//
+// packetCollectorErrors can be used to set error values and each call to
+// WritePacket will remove the first one from the slice and return it until
+// the slice is empty - at that point it will return nil every time.
+func NewTestEndpoint(ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) *TestEndpoint {
+ return &TestEndpoint{
+ Endpoint: ep,
+ WrittenPackets: make([]*stack.PacketBuffer, 0),
+ packetCollectorErrors: packetCollectorErrors,
+ }
+}
+
+// WritePacket stores outbound packets and may return an error if one was
+// injected.
+func (e *TestEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.WrittenPackets = append(e.WrittenPackets, pkt)
+
+ if len(e.packetCollectorErrors) > 0 {
+ nextError := e.packetCollectorErrors[0]
+ e.packetCollectorErrors = e.packetCollectorErrors[1:]
+ return nextError
+ }
+
+ return nil
+}
+
+// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
+// how many random bytes will be copied in the Transport Header.
+// extraHeaderReserveLength indicates how much extra space will be reserved for
+// the other headers. The payload is made from Views of the sizes listed in
+// viewSizes.
+func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer {
+ var views buffer.VectorisedView
+
+ for _, s := range viewSizes {
+ newView := buffer.NewView(s)
+ if _, err := rand.Read(newView); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ views.AppendView(newView)
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength,
+ Data: views,
+ })
+ pkt.NetworkProtocolNumber = proto
+ if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ return pkt
+}