summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/checker/checker.go18
-rw-r--r--pkg/tcpip/header/icmpv6.go5
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/ip_test.go405
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go40
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go62
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go160
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go254
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go44
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go4
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go98
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go6
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go277
-rw-r--r--pkg/tcpip/stack/route.go6
-rw-r--r--pkg/tcpip/stack/stack_test.go47
-rw-r--r--pkg/tcpip/tcpip.go2
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go1
18 files changed, 1275 insertions, 156 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index d4d785cca..6f81b0164 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -178,6 +178,24 @@ func PayloadLen(payloadLength int) NetworkChecker {
}
}
+// IPPayload creates a checker that checks the payload.
+func IPPayload(payload []byte) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ got := h[0].Payload()
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(got) == 0 && len(payload) == 0 {
+ return
+ }
+
+ if diff := cmp.Diff(payload, got); diff != "" {
+ t.Errorf("payload mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// IPv4Options returns a checker that checks the options in an IPv4 packet.
func IPv4Options(want []byte) NetworkChecker {
return func(t *testing.T, h []header.Network) {
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 6be31beeb..4303fc5d5 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -49,11 +49,6 @@ const (
// neighbor advertisement packet.
ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize
- // ICMPv6NeighborAdvertSize is size of a neighbor advertisement
- // including the NDP Target Link Layer option for an Ethernet
- // address.
- ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize
-
// ICMPv6EchoMinimumSize is the minimum size of a valid echo packet.
ICMPv6EchoMinimumSize = 8
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 59710352b..c118a2929 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -12,6 +12,7 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index d436873b6..c1848a19c 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -15,11 +15,13 @@
package ip_test
import (
+ "strings"
"testing"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -1020,3 +1022,406 @@ func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer
_, _ = pkt.NetworkHeader().Consume(netHdrLen)
return pkt
}
+
+func TestWriteHeaderIncludedPacket(t *testing.T) {
+ const (
+ nicID = 1
+ transportProto = 5
+
+ dataLen = 4
+ optionsLen = 4
+ )
+
+ dataBuf := [dataLen]byte{1, 2, 3, 4}
+ data := dataBuf[:]
+
+ ipv4OptionsBuf := [optionsLen]byte{0, 1, 0, 1}
+ ipv4Options := ipv4OptionsBuf[:]
+
+ ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
+ ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
+
+ var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte
+ ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:]
+ if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
+ }
+ if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+
+ tests := []struct {
+ name string
+ protoFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ nicAddr tcpip.Address
+ remoteAddr tcpip.Address
+ pktGen func(*testing.T, tcpip.Address) buffer.View
+ checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
+ expectedErr *tcpip.Error
+ }{
+ {
+ name: "IPv4",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv4MinimumSize + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv4MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv4 with IHL too small",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv4MinimumSize + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 1,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ expectedErr: tcpip.ErrMalformedHeader,
+ },
+ {
+ name: "IPv4 too small",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip[:len(ip)-1])
+ },
+ expectedErr: tcpip.ErrMalformedHeader,
+ },
+ {
+ name: "IPv4 minimum size",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip)
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv4MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.IPFullLength(header.IPv4MinimumSize),
+ checker.IPPayload(nil),
+ )
+ },
+ },
+ {
+ name: "IPv4 with options",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ipHdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ totalLen := ipHdrLen + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv4(hdr.Prepend(ipHdrLen))
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(ipHdrLen),
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options))
+ }
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ if len(netHdr.View()) != hdrLen {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(hdrLen),
+ checker.IPFullLength(uint16(hdrLen+len(data))),
+ checker.IPv4Options(ipv4Options),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv6",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv6MinimumSize + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv6Any {
+ src = localIPv6Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv6MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv6 with extension header",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv6Any {
+ src = localIPv6Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))),
+ checker.IPPayload(ipv6PayloadWithExtHdr),
+ )
+ },
+ },
+ {
+ name: "IPv6 minimum size",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip)
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv6Any {
+ src = localIPv6Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv6MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(header.IPv6MinimumSize),
+ checker.IPPayload(nil),
+ )
+ },
+ },
+ {
+ name: "IPv6 too small",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip[:len(ip)-1])
+ },
+ expectedErr: tcpip.ErrMalformedHeader,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ subTests := []struct {
+ name string
+ srcAddr tcpip.Address
+ }{
+ {
+ name: "unspecified source",
+ srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
+ },
+ {
+ name: "random source",
+ srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
+ },
+ }
+
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
+ })
+ e := channel.New(1, 1280, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
+
+ r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err)
+ }
+ defer r.Release()
+
+ if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(),
+ })); err != test.expectedErr {
+ t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr)
+ }
+
+ if test.expectedErr != nil {
+ return
+ }
+
+ pkt, ok := e.Read()
+ if !ok {
+ t.Fatal("expected a packet to be written")
+ }
+ test.checker(t, pkt.Pkt, subTest.srcAddr)
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index c5ac7b8b5..7e2e53523 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -237,7 +237,10 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
+ return e.writePacket(r, gso, pkt)
+}
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
@@ -347,30 +350,27 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n + len(dropped), nil
}
-// WriteHeaderIncludedPacket writes a packet already containing a network
-// header through the given route.
+// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
- return tcpip.ErrInvalidOptionValue
+ return tcpip.ErrMalformedHeader
}
ip := header.IPv4(h)
- if !ip.IsValid(pkt.Data.Size()) {
- return tcpip.ErrInvalidOptionValue
- }
// Always set the total length.
- ip.SetTotalLength(uint16(pkt.Data.Size()))
+ pktSize := pkt.Data.Size()
+ ip.SetTotalLength(uint16(pktSize))
// Set the source address when zero.
- if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
+ if ip.SourceAddress() == header.IPv4Any {
ip.SetSourceAddress(r.LocalAddress)
}
- // Set the destination. If the packet already included a destination,
- // it will be part of the route.
+ // Set the destination. If the packet already included a destination, it will
+ // be part of the route anyways.
ip.SetDestinationAddress(r.RemoteAddress)
// Set the packet ID when zero.
@@ -387,19 +387,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
ip.SetChecksum(0)
ip.SetChecksum(^ip.CalculateChecksum())
- if r.Loop&stack.PacketLoop != 0 {
- e.HandlePacket(r, pkt.Clone())
- }
- if r.Loop&stack.PacketOut == 0 {
- return nil
+ // Populate the packet buffer's network header and don't allow an invalid
+ // packet to be sent.
+ //
+ // Note that parsing only makes sure that the packet is well formed as per the
+ // wire format. We also want to check if the header's fields are valid before
+ // sending the packet.
+ if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) {
+ return tcpip.ErrMalformedHeader
}
- if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- return nil
+ return e.writePacket(r, nil /* gso */, pkt)
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 9916d783f..819b5d71f 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -15,9 +15,9 @@
package ipv4_test
import (
- "bytes"
"context"
"encoding/hex"
+ "fmt"
"math"
"net"
"testing"
@@ -243,7 +243,7 @@ func TestIPv4Sanity(t *testing.T) {
// Default routes for IPv4 so ICMP can find a route to the remote
// node when attempting to send the ICMP Echo Reply.
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv4EmptySubnet,
NIC: nicID,
},
@@ -369,11 +369,10 @@ func TestIPv4Sanity(t *testing.T) {
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
-func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
- t.Helper()
- // Make a complete array of the sourcePacketInfo packet.
- source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize])
- vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views())
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32) error {
+ // Make a complete array of the sourcePacket packet.
+ source := header.IPv4(packets[0].NetworkHeader().View())
+ vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
source = append(source, vv.ToView()...)
// Make a copy of the IP header, which will be modified in some fields to make
@@ -384,46 +383,49 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
sourceCopy.SetTotalLength(0)
var offset uint16
// Build up an array of the bytes sent.
- var reassembledPayload []byte
+ var reassembledPayload buffer.VectorisedView
for i, packet := range packets {
// Confirm that the packet is valid.
allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views())
- ip := header.IPv4(allBytes.ToView())
- if !ip.IsValid(len(ip)) {
- t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
+ fragmentIPHeader := header.IPv4(allBytes.ToView())
+ if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) {
+ return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader))
}
- if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want {
- t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want)
+ if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want {
+ return fmt.Errorf("fragment #%d: fragmentIPHeader.CalculateChecksum() got %#x, want %#x", i, got, want)
}
- if got, want := len(ip), int(mtu); got > want {
- t.Errorf("fragment is too large, got %d want %d", got, want)
+ if got := len(fragmentIPHeader); got > int(mtu) {
+ return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu)
}
- if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
- t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
+ if got, want := packet.AvailableHeaderBytes(), sourcePacket.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
+ return fmt.Errorf("fragment #%d: should have the same available space for prepending as source: got %d, want %d", i, got, want)
}
- if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want {
- t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want)
+ if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want {
+ return fmt.Errorf("fragment #%d: has wrong network protocol number: got %d, want %d", i, got, want)
}
if i < len(packets)-1 {
sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
} else {
sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset)
}
- reassembledPayload = append(reassembledPayload, ip.Payload()...)
- offset += ip.TotalLength() - uint16(ip.HeaderLength())
+ reassembledPayload.AppendView(packet.TransportHeader().View())
+ reassembledPayload.Append(packet.Data)
+ offset += fragmentIPHeader.TotalLength() - uint16(fragmentIPHeader.HeaderLength())
// Clear out the checksum and length from the ip because we can't compare
// it.
- sourceCopy.SetTotalLength(uint16(len(ip)))
+ sourceCopy.SetTotalLength(uint16(len(fragmentIPHeader)))
sourceCopy.SetChecksum(0)
sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
- if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) {
- t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()]))
+ if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" {
+ return fmt.Errorf("fragment #%d: fragmentIPHeader[:fragmentIPHeader.HeaderLength()] mismatch (-want +got):\n%s", i, diff)
}
}
- expected := source[source.HeaderLength():]
- if !bytes.Equal(reassembledPayload, expected) {
- t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected))
+ expected := buffer.View(source[source.HeaderLength():])
+ if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" {
+ return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
}
+
+ return nil
}
func TestFragmentation(t *testing.T) {
@@ -477,7 +479,9 @@ func TestFragmentation(t *testing.T) {
if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
}
- compareFragments(t, ep.WrittenPackets, source, ft.mtu)
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu); err != nil {
+ t.Error(err)
+ }
})
}
}
@@ -1633,7 +1637,7 @@ func TestPacketQueing(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
NIC: nicID,
},
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 7be35c78b..ead6bedcb 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -252,26 +252,29 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- it, err := ns.Options().Iter(false /* check */)
- if err != nil {
- // Options are not valid as per the wire format, silently drop the packet.
- received.Invalid.Increment()
- return
- }
+ var sourceLinkAddr tcpip.LinkAddress
+ {
+ it, err := ns.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the
+ // packet.
+ received.Invalid.Increment()
+ return
+ }
- sourceLinkAddr, ok := getSourceLinkAddr(it)
- if !ok {
- received.Invalid.Increment()
- return
+ sourceLinkAddr, ok = getSourceLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
}
- unspecifiedSource := r.RemoteAddress == header.IPv6Any
-
// As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST
// NOT be included when the source IP address is the unspecified address.
// Otherwise, on link layers that have addresses this option MUST be
// included in multicast solicitations and SHOULD be included in unicast
// solicitations.
+ unspecifiedSource := r.RemoteAddress == header.IPv6Any
if len(sourceLinkAddr) == 0 {
if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource {
received.Invalid.Increment()
@@ -297,54 +300,71 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- // ICMPv6 Neighbor Solicit messages are always sent to
- // specially crafted IPv6 multicast addresses. As a result, the
- // route we end up with here has as its LocalAddress such a
- // multicast address. It would be nonsense to claim that our
- // source address is a multicast address, so we manually set
- // the source address to the target address requested in the
- // solicit message. Since that requires mutating the route, we
- // must first clone it.
- r := r.Clone()
- defer r.Release()
- r.LocalAddress = targetAddr
-
- // As per RFC 4861 section 7.2.4, if the the source of the solicitation is
- // the unspecified address, the node MUST set the Solicited flag to zero and
- // multicast the advertisement to the all-nodes address.
- solicited := true
+ // As per RFC 4861 section 7.2.4:
+ //
+ // If the source of the solicitation is the unspecified address, the node
+ // MUST [...] and multicast the advertisement to the all-nodes address.
+ //
+ remoteAddr := r.RemoteAddress
if unspecifiedSource {
- solicited = false
- r.RemoteAddress = header.IPv6AllNodesMulticastAddress
+ remoteAddr = header.IPv6AllNodesMulticastAddress
}
- // If the NS has a source link-layer option, use the link address it
- // specifies as the remote link address for the response instead of the
- // source link address of the packet.
+ // Even if we were able to receive a packet from some remote, we may not
+ // have a route to it - the remote may be blocked via routing rules. We must
+ // always consult our routing table and find a route to the remote before
+ // sending any packet.
+ r, err := e.protocol.stack.FindRoute(e.nic.ID(), targetAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ // If we cannot find a route to the destination, silently drop the packet.
+ return
+ }
+ defer r.Release()
+
+ // If the NS has a source link-layer option, resolve the route immediately
+ // to avoid querying the neighbor table when the neighbor entry was updated
+ // as probing the neighbor table for a link address will transition the
+ // entry's state from stale to delay.
+ //
+ // Note, if the source link address is unspecified and this is a unicast
+ // solicitation, we may need to perform neighbor discovery to send the
+ // neighbor advertisement response. This is expected as per RFC 4861 section
+ // 7.2.4:
+ //
+ // Because unicast Neighbor Solicitations are not required to include a
+ // Source Link-Layer Address, it is possible that a node sending a
+ // solicited Neighbor Advertisement does not have a corresponding link-
+ // layer address for its neighbor in its Neighbor Cache. In such
+ // situations, a node will first have to use Neighbor Discovery to
+ // determine the link-layer address of its neighbor (i.e., send out a
+ // multicast Neighbor Solicitation).
//
- // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link
- // address cache for the right destination link address instead of manually
- // patching the route with the remote link address if one is specified in a
- // Source Link-Layer Address option.
if len(sourceLinkAddr) != 0 {
- r.RemoteLinkAddress = sourceLinkAddr
+ r.ResolveWith(sourceLinkAddr)
}
optsSerializer := header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress),
+ header.NDPTargetLinkLayerAddressOption(e.nic.LinkAddress()),
}
+ neighborAdvertSize := header.ICMPv6NeighborAdvertMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()),
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborAdvertSize,
})
- packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+ packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize))
packet.SetType(header.ICMPv6NeighborAdvert)
na := header.NDPNeighborAdvert(packet.NDPPayload())
- na.SetSolicitedFlag(solicited)
+
+ // As per RFC 4861 section 7.2.4:
+ //
+ // If the source of the solicitation is the unspecified address, the node
+ // MUST set the Solicited flag to zero and [..]. Otherwise, the node MUST
+ // set the Solicited flag to one and [..].
+ //
+ na.SetSolicitedFlag(!unspecifiedSource)
na.SetOverrideFlag(true)
na.SetTargetAddress(targetAddr)
- opts := na.Options()
- opts.Serialize(optsSerializer)
+ na.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
@@ -361,7 +381,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertSize {
+ if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize {
received.Invalid.Increment()
return
}
@@ -419,19 +439,19 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// If the NA message has the target link layer option, update the link
// address cache with the link address for the target of the message.
- if len(targetLinkAddr) != 0 {
- if e.nud == nil {
+ if e.nud == nil {
+ if len(targetLinkAddr) != 0 {
e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
- return
}
-
- e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{
- Solicited: na.SolicitedFlag(),
- Override: na.OverrideFlag(),
- IsRouter: na.RouterFlag(),
- })
+ return
}
+ e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{
+ Solicited: na.SolicitedFlag(),
+ Override: na.OverrideFlag(),
+ IsRouter: na.RouterFlag(),
+ })
+
case header.ICMPv6EchoRequest:
received.EchoRequest.Increment()
icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize)
@@ -620,18 +640,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
}
-const (
- ndpSolicitedFlag = 1 << 6
- ndpOverrideFlag = 1 << 5
-
- ndpOptSrcLinkAddr = 1
- ndpOptDstLinkAddr = 2
-
- icmpV6FlagOffset = 4
- icmpV6OptOffset = 24
- icmpV6LengthOffset = 25
-)
-
var _ stack.LinkAddressResolver = (*protocol)(nil)
// LinkAddressProtocol implements stack.LinkAddressResolver.
@@ -647,6 +655,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd
r := stack.Route{
LocalAddress: localAddr,
RemoteAddress: addr,
+ LocalLinkAddress: linkEP.LinkAddress(),
RemoteLinkAddress: remoteLinkAddr,
}
@@ -658,17 +667,20 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd
r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress)
}
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()),
+ }
+ neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize,
+ ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize,
})
- icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
- icmpHdr.SetType(header.ICMPv6NeighborSolicit)
- copy(icmpHdr[icmpV6OptOffset-len(addr):], addr)
- icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr
- icmpHdr[icmpV6LengthOffset] = 1
- copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress())
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
+ packet.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(packet.NDPPayload())
+ ns.SetTargetAddress(addr)
+ ns.Options().Serialize(optsSerializer)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
length := uint16(pkt.Size())
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 3affcc4e4..8dc33c560 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -101,14 +101,19 @@ func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtoco
func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {
}
-type stubNUDHandler struct{}
+type stubNUDHandler struct {
+ probeCount int
+ confirmationCount int
+}
var _ stack.NUDHandler = (*stubNUDHandler)(nil)
-func (*stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) {
+func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) {
+ s.probeCount++
}
-func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) {
+func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) {
+ s.confirmationCount++
}
func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
@@ -118,6 +123,12 @@ var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
stack.NetworkLinkEndpoint
+
+ linkAddr tcpip.LinkAddress
+}
+
+func (i *testInterface) LinkAddress() tcpip.LinkAddress {
+ return i.linkAddr
}
func (*testInterface) ID() tcpip.NICID {
@@ -1492,3 +1503,240 @@ func TestPacketQueing(t *testing.T) {
})
}
}
+
+func TestCallsToNeighborCache(t *testing.T) {
+ tests := []struct {
+ name string
+ createPacket func() header.ICMPv6
+ multicast bool
+ source tcpip.Address
+ destination tcpip.Address
+ wantProbeCount int
+ wantConfirmationCount int
+ }{
+ {
+ name: "Unicast Neighbor Solicitation without source link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(nsSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ return icmp
+ },
+ source: lladdr1,
+ destination: lladdr0,
+ // "The source link-layer address option SHOULD be included in unicast
+ // solicitations." - RFC 4861 section 4.3
+ //
+ // A Neighbor Advertisement needs to be sent in response, but the
+ // Neighbor Cache shouldn't be updated since we have no useful
+ // information about the sender.
+ wantProbeCount: 0,
+ },
+ {
+ name: "Unicast Neighbor Solicitation with source link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(nsSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ ns.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
+ return icmp
+ },
+ source: lladdr1,
+ destination: lladdr0,
+ wantProbeCount: 1,
+ },
+ {
+ name: "Multicast Neighbor Solicitation without source link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(nsSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ return icmp
+ },
+ source: lladdr1,
+ destination: header.SolicitedNodeAddr(lladdr0),
+ // "The source link-layer address option MUST be included in multicast
+ // solicitations." - RFC 4861 section 4.3
+ wantProbeCount: 0,
+ },
+ {
+ name: "Multicast Neighbor Solicitation with source link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(nsSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ ns.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
+ return icmp
+ },
+ source: lladdr1,
+ destination: header.SolicitedNodeAddr(lladdr0),
+ wantProbeCount: 1,
+ },
+ {
+ name: "Unicast Neighbor Advertisement without target link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize
+ icmp := header.ICMPv6(buffer.NewView(naSize))
+ icmp.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(false)
+ na.SetTargetAddress(lladdr1)
+ return icmp
+ },
+ source: lladdr1,
+ destination: lladdr0,
+ // "When responding to unicast solicitations, the target link-layer
+ // address option can be omitted since the sender of the solicitation has
+ // the correct link-layer address; otherwise, it would not be able to
+ // send the unicast solicitation in the first place."
+ // - RFC 4861 section 4.4
+ wantConfirmationCount: 1,
+ },
+ {
+ name: "Unicast Neighbor Advertisement with target link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(naSize))
+ icmp.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(false)
+ na.SetTargetAddress(lladdr1)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+ return icmp
+ },
+ source: lladdr1,
+ destination: lladdr0,
+ wantConfirmationCount: 1,
+ },
+ {
+ name: "Multicast Neighbor Advertisement without target link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(naSize))
+ icmp.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na.SetSolicitedFlag(false)
+ na.SetOverrideFlag(false)
+ na.SetTargetAddress(lladdr1)
+ return icmp
+ },
+ source: lladdr1,
+ destination: header.IPv6AllNodesMulticastAddress,
+ // "Target link-layer address MUST be included for multicast solicitations
+ // in order to avoid infinite Neighbor Solicitation "recursion" when the
+ // peer node does not have a cache entry to return a Neighbor
+ // Advertisements message." - RFC 4861 section 4.4
+ wantConfirmationCount: 0,
+ },
+ {
+ name: "Multicast Neighbor Advertisement with target link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(naSize))
+ icmp.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na.SetSolicitedFlag(false)
+ na.SetOverrideFlag(false)
+ na.SetTargetAddress(lladdr1)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+ return icmp
+ },
+ source: lladdr1,
+ destination: header.IPv6AllNodesMulticastAddress,
+ wantConfirmationCount: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ UseNeighborCache: true,
+ })
+ {
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+ nudHandler := &stubNUDHandler{}
+ ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{})
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
+ r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ }
+ defer r.Release()
+
+ // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch.
+ r.LocalAddress = test.destination
+
+ icmp := test.createPacket()
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{}))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize,
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: r.RemoteAddress,
+ DstAddr: r.LocalAddress,
+ })
+ ep.HandlePacket(&r, pkt)
+
+ // Confirm the endpoint calls the correct NUDHandler method.
+ if nudHandler.probeCount != test.wantProbeCount {
+ t.Errorf("got nudHandler.probeCount = %d, want = %d", nudHandler.probeCount, test.wantProbeCount)
+ }
+ if nudHandler.confirmationCount != test.wantConfirmationCount {
+ t.Errorf("got nudHandler.confirmationCount = %d, want = %d", nudHandler.confirmationCount, test.wantConfirmationCount)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 2bd8f4ece..632914dd6 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -426,7 +426,10 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, p
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
+ return e.writePacket(r, gso, pkt, params.Protocol)
+}
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
@@ -468,7 +471,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
if e.packetMustBeFragmented(pkt, gso) {
- sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
@@ -569,11 +572,40 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n + len(dropped), nil
}
-// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
-// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
- // TODO(b/146666412): Support IPv6 header-included packets.
- return tcpip.ErrNotSupported
+// WriteHeaderIncludedPacker implements stack.NetworkEndpoint.
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+ // The packet already has an IP header, but there are a few required checks.
+ h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return tcpip.ErrMalformedHeader
+ }
+ ip := header.IPv6(h)
+
+ // Always set the payload length.
+ pktSize := pkt.Data.Size()
+ ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize))
+
+ // Set the source address when zero.
+ if ip.SourceAddress() == header.IPv6Any {
+ ip.SetSourceAddress(r.LocalAddress)
+ }
+
+ // Set the destination. If the packet already included a destination, it will
+ // be part of the route anyways.
+ ip.SetDestinationAddress(r.RemoteAddress)
+
+ // Populate the packet buffer's network header and don't allow an invalid
+ // packet to be sent.
+ //
+ // Note that parsing only makes sure that the packet is well formed as per the
+ // wire format. We also want to check if the header's fields are valid before
+ // sending the packet.
+ proto, _, _, _, ok := parse.IPv6(pkt)
+ if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) {
+ return tcpip.ErrMalformedHeader
+ }
+
+ return e.writePacket(r, nil /* gso */, pkt, proto)
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index e792ca9e2..bee18d1a8 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -57,8 +57,8 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
t.Helper()
// Receive ICMP packet.
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 9033a9ed5..ac20f217e 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "context"
"strings"
"testing"
"time"
@@ -398,16 +399,17 @@ func TestNeighorSolicitationResponse(t *testing.T) {
}
tests := []struct {
- name string
- nsOpts header.NDPOptionsSerializer
- nsSrcLinkAddr tcpip.LinkAddress
- nsSrc tcpip.Address
- nsDst tcpip.Address
- nsInvalid bool
- naDstLinkAddr tcpip.LinkAddress
- naSolicited bool
- naSrc tcpip.Address
- naDst tcpip.Address
+ name string
+ nsOpts header.NDPOptionsSerializer
+ nsSrcLinkAddr tcpip.LinkAddress
+ nsSrc tcpip.Address
+ nsDst tcpip.Address
+ nsInvalid bool
+ naDstLinkAddr tcpip.LinkAddress
+ naSolicited bool
+ naSrc tcpip.Address
+ naDst tcpip.Address
+ performsLinkResolution bool
}{
{
name: "Unspecified source to solicited-node multicast destination",
@@ -416,7 +418,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
nsSrc: header.IPv6Any,
nsDst: nicAddrSNMC,
nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
+ naDstLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
naSolicited: false,
naSrc: nicAddr,
naDst: header.IPv6AllNodesMulticastAddress,
@@ -449,7 +451,6 @@ func TestNeighorSolicitationResponse(t *testing.T) {
nsDst: nicAddr,
nsInvalid: true,
},
-
{
name: "Specified source with 1 source ll to multicast destination",
nsOpts: header.NDPOptionsSerializer{
@@ -509,6 +510,10 @@ func TestNeighorSolicitationResponse(t *testing.T) {
naSolicited: true,
naSrc: nicAddr,
naDst: remoteAddr,
+ // Since we send a unicast solicitations to a node without an entry for
+ // the remote, the node needs to perform neighbor discovery to get the
+ // remote's link address to send the advertisement response.
+ performsLinkResolution: true,
},
{
name: "Specified source with 1 source ll to unicast destination",
@@ -615,11 +620,78 @@ func TestNeighorSolicitationResponse(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- p, got := e.Read()
+ if test.performsLinkResolution {
+ p, got := e.ReadContext(context.Background())
+ if !got {
+ t.Fatal("expected an NDP NS response")
+ }
+
+ if p.Route.LocalAddress != nicAddr {
+ t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, nicAddr)
+ }
+ if p.Route.LocalLinkAddress != nicLinkAddr {
+ t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr)
+ }
+ respNSDst := header.SolicitedNodeAddr(test.nsSrc)
+ if p.Route.RemoteAddress != respNSDst {
+ t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst)
+ }
+ if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(nicAddr),
+ checker.DstAddr(respNSDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(test.nsSrc),
+ checker.NDPNSOptions([]header.NDPOption{
+ header.NDPSourceLinkLayerAddressOption(nicLinkAddr),
+ }),
+ ))
+
+ ser := header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ }
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(test.nsSrc)
+ na.Options().Serialize(ser)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.nsSrc,
+ DstAddr: nicAddr,
+ })
+ e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ p, got := e.ReadContext(context.Background())
if !got {
t.Fatal("expected an NDP NA response")
}
+ if p.Route.LocalAddress != test.naSrc {
+ t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc)
+ }
+ if p.Route.LocalLinkAddress != nicLinkAddr {
+ t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr)
+ }
+ if p.Route.RemoteAddress != test.naDst {
+ t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst)
+ }
if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index eba97334e..d09ebe7fa 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -123,6 +123,7 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/ports",
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 4d69a4de1..be61a21af 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -406,9 +406,9 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
// INCOMPLETE state." - RFC 4861 section 7.2.5
case Reachable, Stale, Delay, Probe:
- sameLinkAddr := e.neigh.LinkAddr == linkAddr
+ isLinkAddrDifferent := len(linkAddr) != 0 && e.neigh.LinkAddr != linkAddr
- if !sameLinkAddr {
+ if isLinkAddrDifferent {
if !flags.Override {
if e.neigh.State == Reachable {
e.dispatchChangeEventLocked(Stale)
@@ -431,7 +431,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
}
}
- if flags.Solicited && (flags.Override || sameLinkAddr) {
+ if flags.Solicited && (flags.Override || !isLinkAddrDifferent) {
if e.neigh.State != Reachable {
e.dispatchChangeEventLocked(Reachable)
}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index e79abebca..3ee2a3b31 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -83,15 +83,18 @@ func eventDiffOptsWithSort() []cmp.Option {
// | Reachable | Stale | Reachable timer expired | | Changed |
// | Reachable | Stale | Probe or confirmation w/ different address | | Changed |
// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
// | Stale | Delay | Packet sent | | Changed |
// | Delay | Reachable | Upper-layer confirmation | | Changed |
// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
// | Delay | Stale | Probe or confirmation w/ different address | | Changed |
// | Delay | Probe | Delay timer expired | Send probe | Changed |
// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed |
+// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
@@ -1370,6 +1373,77 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
nudDisp.mu.Unlock()
}
+func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, _ := entryTestSetup(c)
@@ -1752,6 +1826,100 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
nudDisp.mu.Unlock()
}
+func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 1
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.Advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, _ := entryTestSetup(c)
@@ -2665,6 +2833,115 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
nudDisp.mu.Unlock()
}
+func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.Advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ }
+ e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+
+ clock.Advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
func TestEntryProbeToFailed(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 25f80c1f8..b76e2d37b 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -126,6 +126,12 @@ func (r *Route) GSOMaxSize() uint32 {
return 0
}
+// ResolveWith immediately resolves a route with the specified remote link
+// address.
+func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
+ r.RemoteLinkAddress = addr
+}
+
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
// notified when address resolution is complete (success or not).
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 38994cca1..e75f58c64 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -34,6 +34,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -3498,6 +3499,52 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
}
+func TestResolveWith(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID = 1
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
+ })
+ ep := channel.New(0, defaultMTU, "")
+ ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ addr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address([]byte{192, 168, 1, 58}),
+ PrefixLen: 24,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
+
+ remoteAddr := tcpip.Address([]byte{192, 168, 1, 59})
+ r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ // Should initially require resolution.
+ if !r.IsResolutionRequired() {
+ t.Fatal("got r.IsResolutionRequired() = false, want = true")
+ }
+
+ // Manually resolving the route should no longer require resolution.
+ r.ResolveWith("\x01")
+ if r.IsResolutionRequired() {
+ t.Fatal("got r.IsResolutionRequired() = true, want = false")
+ }
+}
+
// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its
// associated address is removed should not cause a panic.
func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index c42bb0991..d77848d61 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -111,6 +111,7 @@ var (
ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
ErrNotPermitted = &Error{msg: "operation not permitted"}
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
+ ErrMalformedHeader = &Error{msg: "header is malformed"}
)
var messageToError map[string]*Error
@@ -159,6 +160,7 @@ func StringToError(s string) *Error {
ErrBroadcastDisabled,
ErrNotPermitted,
ErrAddressFamilyNotSupported,
+ ErrMalformedHeader,
}
messageToError = make(map[string]*Error)
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index c67e0ba95..3ae6cc221 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -81,6 +81,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.ID = r.id
ep.route = r.route.Clone()
ep.dstPort = r.id.RemotePort
+ ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.route.NetProto}
ep.RegisterNICID = r.route.NICID()
ep.boundPortFlags = ep.portFlags