diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/checker/checker.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/header/icmpv6.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/network/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 405 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 62 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 160 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 254 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 44 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 98 | ||||
-rw-r--r-- | pkg/tcpip/stack/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 277 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 47 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/forwarder.go | 1 |
18 files changed, 1275 insertions, 156 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index d4d785cca..6f81b0164 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -178,6 +178,24 @@ func PayloadLen(payloadLength int) NetworkChecker { } } +// IPPayload creates a checker that checks the payload. +func IPPayload(payload []byte) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + got := h[0].Payload() + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(got) == 0 && len(payload) == 0 { + return + } + + if diff := cmp.Diff(payload, got); diff != "" { + t.Errorf("payload mismatch (-want +got):\n%s", diff) + } + } +} + // IPv4Options returns a checker that checks the options in an IPv4 packet. func IPv4Options(want []byte) NetworkChecker { return func(t *testing.T, h []header.Network) { diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 6be31beeb..4303fc5d5 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -49,11 +49,6 @@ const ( // neighbor advertisement packet. ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize - // ICMPv6NeighborAdvertSize is size of a neighbor advertisement - // including the NDP Target Link Layer option for an Ethernet - // address. - ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize - // ICMPv6EchoMinimumSize is the minimum size of a valid echo packet. ICMPv6EchoMinimumSize = 8 diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 59710352b..c118a2929 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -12,6 +12,7 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index d436873b6..c1848a19c 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,11 +15,13 @@ package ip_test import ( + "strings" "testing" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -1020,3 +1022,406 @@ func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer _, _ = pkt.NetworkHeader().Consume(netHdrLen) return pkt } + +func TestWriteHeaderIncludedPacket(t *testing.T) { + const ( + nicID = 1 + transportProto = 5 + + dataLen = 4 + optionsLen = 4 + ) + + dataBuf := [dataLen]byte{1, 2, 3, 4} + data := dataBuf[:] + + ipv4OptionsBuf := [optionsLen]byte{0, 1, 0, 1} + ipv4Options := ipv4OptionsBuf[:] + + ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} + ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] + + var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte + ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:] + if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) + } + if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + nicAddr tcpip.Address + remoteAddr tcpip.Address + pktGen func(*testing.T, tcpip.Address) buffer.View + checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) + expectedErr *tcpip.Error + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv4MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv4MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv4 with IHL too small", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv4MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 1, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + { + name: "IPv4 too small", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip[:len(ip)-1]) + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + { + name: "IPv4 minimum size", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip) + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv4MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(header.IPv4MinimumSize), + checker.IPPayload(nil), + ) + }, + }, + { + name: "IPv4 with options", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ipHdrLen := header.IPv4MinimumSize + len(ipv4Options) + totalLen := ipHdrLen + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(ipHdrLen)) + ip.Encode(&header.IPv4Fields{ + IHL: uint8(ipHdrLen), + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options)) + } + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + hdrLen := header.IPv4MinimumSize + len(ipv4Options) + if len(netHdr.View()) != hdrLen { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(hdrLen), + checker.IPFullLength(uint16(hdrLen+len(data))), + checker.IPv4Options(ipv4Options), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv6MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv6MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv6 with extension header", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))), + checker.IPPayload(ipv6PayloadWithExtHdr), + ) + }, + }, + { + name: "IPv6 minimum size", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip) + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv6MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(header.IPv6MinimumSize), + checker.IPPayload(nil), + ) + }, + }, + { + name: "IPv6 too small", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip[:len(ip)-1]) + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + subTests := []struct { + name string + srcAddr tcpip.Address + }{ + { + name: "unspecified source", + srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), + }, + { + name: "random source", + srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), + }, + } + + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, + }) + e := channel.New(1, 1280, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) + + r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) + } + defer r.Release() + + if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(), + })); err != test.expectedErr { + t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) + } + + if test.expectedErr != nil { + return + } + + pkt, ok := e.Read() + if !ok { + t.Fatal("expected a packet to be written") + } + test.checker(t, pkt.Pkt, subTest.srcAddr) + }) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index c5ac7b8b5..7e2e53523 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -237,7 +237,10 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) + return e.writePacket(r, gso, pkt) +} +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) @@ -347,30 +350,27 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n + len(dropped), nil } -// WriteHeaderIncludedPacket writes a packet already containing a network -// header through the given route. +// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { - return tcpip.ErrInvalidOptionValue + return tcpip.ErrMalformedHeader } ip := header.IPv4(h) - if !ip.IsValid(pkt.Data.Size()) { - return tcpip.ErrInvalidOptionValue - } // Always set the total length. - ip.SetTotalLength(uint16(pkt.Data.Size())) + pktSize := pkt.Data.Size() + ip.SetTotalLength(uint16(pktSize)) // Set the source address when zero. - if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { + if ip.SourceAddress() == header.IPv4Any { ip.SetSourceAddress(r.LocalAddress) } - // Set the destination. If the packet already included a destination, - // it will be part of the route. + // Set the destination. If the packet already included a destination, it will + // be part of the route anyways. ip.SetDestinationAddress(r.RemoteAddress) // Set the packet ID when zero. @@ -387,19 +387,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) - if r.Loop&stack.PacketLoop != 0 { - e.HandlePacket(r, pkt.Clone()) - } - if r.Loop&stack.PacketOut == 0 { - return nil + // Populate the packet buffer's network header and don't allow an invalid + // packet to be sent. + // + // Note that parsing only makes sure that the packet is well formed as per the + // wire format. We also want to check if the header's fields are valid before + // sending the packet. + if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) { + return tcpip.ErrMalformedHeader } - if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - return err - } - r.Stats().IP.PacketsSent.Increment() - return nil + return e.writePacket(r, nil /* gso */, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 9916d783f..819b5d71f 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -15,9 +15,9 @@ package ipv4_test import ( - "bytes" "context" "encoding/hex" + "fmt" "math" "net" "testing" @@ -243,7 +243,7 @@ func TestIPv4Sanity(t *testing.T) { // Default routes for IPv4 so ICMP can find a route to the remote // node when attempting to send the ICMP Echo Reply. s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: header.IPv4EmptySubnet, NIC: nicID, }, @@ -369,11 +369,10 @@ func TestIPv4Sanity(t *testing.T) { // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { - t.Helper() - // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize]) - vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views()) +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32) error { + // Make a complete array of the sourcePacket packet. + source := header.IPv4(packets[0].NetworkHeader().View()) + vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) source = append(source, vv.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make @@ -384,46 +383,49 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI sourceCopy.SetTotalLength(0) var offset uint16 // Build up an array of the bytes sent. - var reassembledPayload []byte + var reassembledPayload buffer.VectorisedView for i, packet := range packets { // Confirm that the packet is valid. allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) - ip := header.IPv4(allBytes.ToView()) - if !ip.IsValid(len(ip)) { - t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) + fragmentIPHeader := header.IPv4(allBytes.ToView()) + if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) { + return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader)) } - if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want { - t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want) + if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { + return fmt.Errorf("fragment #%d: fragmentIPHeader.CalculateChecksum() got %#x, want %#x", i, got, want) } - if got, want := len(ip), int(mtu); got > want { - t.Errorf("fragment is too large, got %d want %d", got, want) + if got := len(fragmentIPHeader); got > int(mtu) { + return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu) } - if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { - t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) + if got, want := packet.AvailableHeaderBytes(), sourcePacket.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { + return fmt.Errorf("fragment #%d: should have the same available space for prepending as source: got %d, want %d", i, got, want) } - if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { - t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want) + if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { + return fmt.Errorf("fragment #%d: has wrong network protocol number: got %d, want %d", i, got, want) } if i < len(packets)-1 { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) } else { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) } - reassembledPayload = append(reassembledPayload, ip.Payload()...) - offset += ip.TotalLength() - uint16(ip.HeaderLength()) + reassembledPayload.AppendView(packet.TransportHeader().View()) + reassembledPayload.Append(packet.Data) + offset += fragmentIPHeader.TotalLength() - uint16(fragmentIPHeader.HeaderLength()) // Clear out the checksum and length from the ip because we can't compare // it. - sourceCopy.SetTotalLength(uint16(len(ip))) + sourceCopy.SetTotalLength(uint16(len(fragmentIPHeader))) sourceCopy.SetChecksum(0) sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) - if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) { - t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()])) + if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { + return fmt.Errorf("fragment #%d: fragmentIPHeader[:fragmentIPHeader.HeaderLength()] mismatch (-want +got):\n%s", i, diff) } } - expected := source[source.HeaderLength():] - if !bytes.Equal(reassembledPayload, expected) { - t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected)) + expected := buffer.View(source[source.HeaderLength():]) + if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" { + return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) } + + return nil } func TestFragmentation(t *testing.T) { @@ -477,7 +479,9 @@ func TestFragmentation(t *testing.T) { if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - compareFragments(t, ep.WrittenPackets, source, ft.mtu) + if err := compareFragments(ep.WrittenPackets, source, ft.mtu); err != nil { + t.Error(err) + } }) } } @@ -1633,7 +1637,7 @@ func TestPacketQueing(t *testing.T) { } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), NIC: nicID, }, diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 7be35c78b..ead6bedcb 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -252,26 +252,29 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - it, err := ns.Options().Iter(false /* check */) - if err != nil { - // Options are not valid as per the wire format, silently drop the packet. - received.Invalid.Increment() - return - } + var sourceLinkAddr tcpip.LinkAddress + { + it, err := ns.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the + // packet. + received.Invalid.Increment() + return + } - sourceLinkAddr, ok := getSourceLinkAddr(it) - if !ok { - received.Invalid.Increment() - return + sourceLinkAddr, ok = getSourceLinkAddr(it) + if !ok { + received.Invalid.Increment() + return + } } - unspecifiedSource := r.RemoteAddress == header.IPv6Any - // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST // NOT be included when the source IP address is the unspecified address. // Otherwise, on link layers that have addresses this option MUST be // included in multicast solicitations and SHOULD be included in unicast // solicitations. + unspecifiedSource := r.RemoteAddress == header.IPv6Any if len(sourceLinkAddr) == 0 { if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource { received.Invalid.Increment() @@ -297,54 +300,71 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - // ICMPv6 Neighbor Solicit messages are always sent to - // specially crafted IPv6 multicast addresses. As a result, the - // route we end up with here has as its LocalAddress such a - // multicast address. It would be nonsense to claim that our - // source address is a multicast address, so we manually set - // the source address to the target address requested in the - // solicit message. Since that requires mutating the route, we - // must first clone it. - r := r.Clone() - defer r.Release() - r.LocalAddress = targetAddr - - // As per RFC 4861 section 7.2.4, if the the source of the solicitation is - // the unspecified address, the node MUST set the Solicited flag to zero and - // multicast the advertisement to the all-nodes address. - solicited := true + // As per RFC 4861 section 7.2.4: + // + // If the source of the solicitation is the unspecified address, the node + // MUST [...] and multicast the advertisement to the all-nodes address. + // + remoteAddr := r.RemoteAddress if unspecifiedSource { - solicited = false - r.RemoteAddress = header.IPv6AllNodesMulticastAddress + remoteAddr = header.IPv6AllNodesMulticastAddress } - // If the NS has a source link-layer option, use the link address it - // specifies as the remote link address for the response instead of the - // source link address of the packet. + // Even if we were able to receive a packet from some remote, we may not + // have a route to it - the remote may be blocked via routing rules. We must + // always consult our routing table and find a route to the remote before + // sending any packet. + r, err := e.protocol.stack.FindRoute(e.nic.ID(), targetAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + // If we cannot find a route to the destination, silently drop the packet. + return + } + defer r.Release() + + // If the NS has a source link-layer option, resolve the route immediately + // to avoid querying the neighbor table when the neighbor entry was updated + // as probing the neighbor table for a link address will transition the + // entry's state from stale to delay. + // + // Note, if the source link address is unspecified and this is a unicast + // solicitation, we may need to perform neighbor discovery to send the + // neighbor advertisement response. This is expected as per RFC 4861 section + // 7.2.4: + // + // Because unicast Neighbor Solicitations are not required to include a + // Source Link-Layer Address, it is possible that a node sending a + // solicited Neighbor Advertisement does not have a corresponding link- + // layer address for its neighbor in its Neighbor Cache. In such + // situations, a node will first have to use Neighbor Discovery to + // determine the link-layer address of its neighbor (i.e., send out a + // multicast Neighbor Solicitation). // - // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link - // address cache for the right destination link address instead of manually - // patching the route with the remote link address if one is specified in a - // Source Link-Layer Address option. if len(sourceLinkAddr) != 0 { - r.RemoteLinkAddress = sourceLinkAddr + r.ResolveWith(sourceLinkAddr) } optsSerializer := header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress), + header.NDPTargetLinkLayerAddressOption(e.nic.LinkAddress()), } + neighborAdvertSize := header.ICMPv6NeighborAdvertMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()), + ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborAdvertSize, }) - packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber + packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(packet.NDPPayload()) - na.SetSolicitedFlag(solicited) + + // As per RFC 4861 section 7.2.4: + // + // If the source of the solicitation is the unspecified address, the node + // MUST set the Solicited flag to zero and [..]. Otherwise, the node MUST + // set the Solicited flag to one and [..]. + // + na.SetSolicitedFlag(!unspecifiedSource) na.SetOverrideFlag(true) na.SetTargetAddress(targetAddr) - opts := na.Options() - opts.Serialize(optsSerializer) + na.Options().Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) @@ -361,7 +381,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertSize { + if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize { received.Invalid.Increment() return } @@ -419,19 +439,19 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the NA message has the target link layer option, update the link // address cache with the link address for the target of the message. - if len(targetLinkAddr) != 0 { - if e.nud == nil { + if e.nud == nil { + if len(targetLinkAddr) != 0 { e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr) - return } - - e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ - Solicited: na.SolicitedFlag(), - Override: na.OverrideFlag(), - IsRouter: na.RouterFlag(), - }) + return } + e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ + Solicited: na.SolicitedFlag(), + Override: na.OverrideFlag(), + IsRouter: na.RouterFlag(), + }) + case header.ICMPv6EchoRequest: received.EchoRequest.Increment() icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) @@ -620,18 +640,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } } -const ( - ndpSolicitedFlag = 1 << 6 - ndpOverrideFlag = 1 << 5 - - ndpOptSrcLinkAddr = 1 - ndpOptDstLinkAddr = 2 - - icmpV6FlagOffset = 4 - icmpV6OptOffset = 24 - icmpV6LengthOffset = 25 -) - var _ stack.LinkAddressResolver = (*protocol)(nil) // LinkAddressProtocol implements stack.LinkAddressResolver. @@ -647,6 +655,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd r := stack.Route{ LocalAddress: localAddr, RemoteAddress: addr, + LocalLinkAddress: linkEP.LinkAddress(), RemoteLinkAddress: remoteLinkAddr, } @@ -658,17 +667,20 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress) } + optsSerializer := header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()), + } + neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize, + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize, }) - icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber - icmpHdr.SetType(header.ICMPv6NeighborSolicit) - copy(icmpHdr[icmpV6OptOffset-len(addr):], addr) - icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr - icmpHdr[icmpV6LengthOffset] = 1 - copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress()) - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) + packet.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(packet.NDPPayload()) + ns.SetTargetAddress(addr) + ns.Options().Serialize(optsSerializer) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) length := uint16(pkt.Size()) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 3affcc4e4..8dc33c560 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -101,14 +101,19 @@ func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtoco func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { } -type stubNUDHandler struct{} +type stubNUDHandler struct { + probeCount int + confirmationCount int +} var _ stack.NUDHandler = (*stubNUDHandler)(nil) -func (*stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { +func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { + s.probeCount++ } -func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { +func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { + s.confirmationCount++ } func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { @@ -118,6 +123,12 @@ var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { stack.NetworkLinkEndpoint + + linkAddr tcpip.LinkAddress +} + +func (i *testInterface) LinkAddress() tcpip.LinkAddress { + return i.linkAddr } func (*testInterface) ID() tcpip.NICID { @@ -1492,3 +1503,240 @@ func TestPacketQueing(t *testing.T) { }) } } + +func TestCallsToNeighborCache(t *testing.T) { + tests := []struct { + name string + createPacket func() header.ICMPv6 + multicast bool + source tcpip.Address + destination tcpip.Address + wantProbeCount int + wantConfirmationCount int + }{ + { + name: "Unicast Neighbor Solicitation without source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + return icmp + }, + source: lladdr1, + destination: lladdr0, + // "The source link-layer address option SHOULD be included in unicast + // solicitations." - RFC 4861 section 4.3 + // + // A Neighbor Advertisement needs to be sent in response, but the + // Neighbor Cache shouldn't be updated since we have no useful + // information about the sender. + wantProbeCount: 0, + }, + { + name: "Unicast Neighbor Solicitation with source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + ns.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: lladdr0, + wantProbeCount: 1, + }, + { + name: "Multicast Neighbor Solicitation without source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + return icmp + }, + source: lladdr1, + destination: header.SolicitedNodeAddr(lladdr0), + // "The source link-layer address option MUST be included in multicast + // solicitations." - RFC 4861 section 4.3 + wantProbeCount: 0, + }, + { + name: "Multicast Neighbor Solicitation with source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + ns.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: header.SolicitedNodeAddr(lladdr0), + wantProbeCount: 1, + }, + { + name: "Unicast Neighbor Advertisement without target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + return icmp + }, + source: lladdr1, + destination: lladdr0, + // "When responding to unicast solicitations, the target link-layer + // address option can be omitted since the sender of the solicitation has + // the correct link-layer address; otherwise, it would not be able to + // send the unicast solicitation in the first place." + // - RFC 4861 section 4.4 + wantConfirmationCount: 1, + }, + { + name: "Unicast Neighbor Advertisement with target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: lladdr0, + wantConfirmationCount: 1, + }, + { + name: "Multicast Neighbor Advertisement without target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(false) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + return icmp + }, + source: lladdr1, + destination: header.IPv6AllNodesMulticastAddress, + // "Target link-layer address MUST be included for multicast solicitations + // in order to avoid infinite Neighbor Solicitation "recursion" when the + // peer node does not have a cache entry to return a Neighbor + // Advertisements message." - RFC 4861 section 4.4 + wantConfirmationCount: 0, + }, + { + name: "Multicast Neighbor Advertisement with target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(false) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: header.IPv6AllNodesMulticastAddress, + wantConfirmationCount: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + UseNeighborCache: true, + }) + { + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + netProto := s.NetworkProtocolInstance(ProtocolNumber) + if netProto == nil { + t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) + } + nudHandler := &stubNUDHandler{} + ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + } + defer r.Release() + + // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch. + r.LocalAddress = test.destination + + icmp := test.createPacket() + icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{})) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: r.RemoteAddress, + DstAddr: r.LocalAddress, + }) + ep.HandlePacket(&r, pkt) + + // Confirm the endpoint calls the correct NUDHandler method. + if nudHandler.probeCount != test.wantProbeCount { + t.Errorf("got nudHandler.probeCount = %d, want = %d", nudHandler.probeCount, test.wantProbeCount) + } + if nudHandler.confirmationCount != test.wantConfirmationCount { + t.Errorf("got nudHandler.confirmationCount = %d, want = %d", nudHandler.confirmationCount, test.wantConfirmationCount) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 2bd8f4ece..632914dd6 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -426,7 +426,10 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, p // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) + return e.writePacket(r, gso, pkt, params.Protocol) +} +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) @@ -468,7 +471,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -569,11 +572,40 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n + len(dropped), nil } -// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet -// supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { - // TODO(b/146666412): Support IPv6 header-included packets. - return tcpip.ErrNotSupported +// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { + // The packet already has an IP header, but there are a few required checks. + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return tcpip.ErrMalformedHeader + } + ip := header.IPv6(h) + + // Always set the payload length. + pktSize := pkt.Data.Size() + ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) + + // Set the source address when zero. + if ip.SourceAddress() == header.IPv6Any { + ip.SetSourceAddress(r.LocalAddress) + } + + // Set the destination. If the packet already included a destination, it will + // be part of the route anyways. + ip.SetDestinationAddress(r.RemoteAddress) + + // Populate the packet buffer's network header and don't allow an invalid + // packet to be sent. + // + // Note that parsing only makes sure that the packet is well formed as per the + // wire format. We also want to check if the header's fields are valid before + // sending the packet. + proto, _, _, _, ok := parse.IPv6(pkt) + if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) { + return tcpip.ErrMalformedHeader + } + + return e.writePacket(r, nil /* gso */, pkt, proto) } // HandlePacket is called by the link layer when new ipv6 packets arrive for diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index e792ca9e2..bee18d1a8 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -57,8 +57,8 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst t.Helper() // Receive ICMP packet. - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize)) pkt.SetType(header.ICMPv6NeighborAdvert) pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 9033a9ed5..ac20f217e 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "context" "strings" "testing" "time" @@ -398,16 +399,17 @@ func TestNeighorSolicitationResponse(t *testing.T) { } tests := []struct { - name string - nsOpts header.NDPOptionsSerializer - nsSrcLinkAddr tcpip.LinkAddress - nsSrc tcpip.Address - nsDst tcpip.Address - nsInvalid bool - naDstLinkAddr tcpip.LinkAddress - naSolicited bool - naSrc tcpip.Address - naDst tcpip.Address + name string + nsOpts header.NDPOptionsSerializer + nsSrcLinkAddr tcpip.LinkAddress + nsSrc tcpip.Address + nsDst tcpip.Address + nsInvalid bool + naDstLinkAddr tcpip.LinkAddress + naSolicited bool + naSrc tcpip.Address + naDst tcpip.Address + performsLinkResolution bool }{ { name: "Unspecified source to solicited-node multicast destination", @@ -416,7 +418,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { nsSrc: header.IPv6Any, nsDst: nicAddrSNMC, nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, + naDstLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), naSolicited: false, naSrc: nicAddr, naDst: header.IPv6AllNodesMulticastAddress, @@ -449,7 +451,6 @@ func TestNeighorSolicitationResponse(t *testing.T) { nsDst: nicAddr, nsInvalid: true, }, - { name: "Specified source with 1 source ll to multicast destination", nsOpts: header.NDPOptionsSerializer{ @@ -509,6 +510,10 @@ func TestNeighorSolicitationResponse(t *testing.T) { naSolicited: true, naSrc: nicAddr, naDst: remoteAddr, + // Since we send a unicast solicitations to a node without an entry for + // the remote, the node needs to perform neighbor discovery to get the + // remote's link address to send the advertisement response. + performsLinkResolution: true, }, { name: "Specified source with 1 source ll to unicast destination", @@ -615,11 +620,78 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - p, got := e.Read() + if test.performsLinkResolution { + p, got := e.ReadContext(context.Background()) + if !got { + t.Fatal("expected an NDP NS response") + } + + if p.Route.LocalAddress != nicAddr { + t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, nicAddr) + } + if p.Route.LocalLinkAddress != nicLinkAddr { + t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) + } + respNSDst := header.SolicitedNodeAddr(test.nsSrc) + if p.Route.RemoteAddress != respNSDst { + t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) + } + if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(nicAddr), + checker.DstAddr(respNSDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(test.nsSrc), + checker.NDPNSOptions([]header.NDPOption{ + header.NDPSourceLinkLayerAddressOption(nicLinkAddr), + }), + )) + + ser := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + } + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(test.nsSrc) + na.Options().Serialize(ser) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: test.nsSrc, + DstAddr: nicAddr, + }) + e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + p, got := e.ReadContext(context.Background()) if !got { t.Fatal("expected an NDP NA response") } + if p.Route.LocalAddress != test.naSrc { + t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc) + } + if p.Route.LocalLinkAddress != nicLinkAddr { + t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) + } + if p.Route.RemoteAddress != test.naDst { + t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) + } if p.Route.RemoteLinkAddress != test.naDstLinkAddr { t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index eba97334e..d09ebe7fa 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -123,6 +123,7 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 4d69a4de1..be61a21af 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -406,9 +406,9 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // INCOMPLETE state." - RFC 4861 section 7.2.5 case Reachable, Stale, Delay, Probe: - sameLinkAddr := e.neigh.LinkAddr == linkAddr + isLinkAddrDifferent := len(linkAddr) != 0 && e.neigh.LinkAddr != linkAddr - if !sameLinkAddr { + if isLinkAddrDifferent { if !flags.Override { if e.neigh.State == Reachable { e.dispatchChangeEventLocked(Stale) @@ -431,7 +431,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } } - if flags.Solicited && (flags.Override || sameLinkAddr) { + if flags.Solicited && (flags.Override || !isLinkAddrDifferent) { if e.neigh.State != Reachable { e.dispatchChangeEventLocked(Reachable) } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index e79abebca..3ee2a3b31 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -83,15 +83,18 @@ func eventDiffOptsWithSort() []cmp.Option { // | Reachable | Stale | Reachable timer expired | | Changed | // | Reachable | Stale | Probe or confirmation w/ different address | | Changed | // | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | // | Stale | Stale | Override confirmation | Update LinkAddr | Changed | // | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | // | Stale | Delay | Packet sent | | Changed | // | Delay | Reachable | Upper-layer confirmation | | Changed | // | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | // | Delay | Stale | Probe or confirmation w/ different address | | Changed | // | Delay | Probe | Delay timer expired | Send probe | Changed | // | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | // | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | +// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | // | Probe | Stale | Probe or confirmation w/ different address | | Changed | // | Probe | Probe | Retransmit timer expired | Send probe | Changed | // | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | @@ -1370,6 +1373,77 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { nudDisp.mu.Unlock() } +func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, _ := entryTestSetup(c) @@ -1752,6 +1826,100 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { nudDisp.mu.Unlock() } +func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 1 + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, _ := entryTestSetup(c) @@ -2665,6 +2833,115 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin nudDisp.mu.Unlock() } +func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.Advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + } + e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + func TestEntryProbeToFailed(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 25f80c1f8..b76e2d37b 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -126,6 +126,12 @@ func (r *Route) GSOMaxSize() uint32 { return 0 } +// ResolveWith immediately resolves a route with the specified remote link +// address. +func (r *Route) ResolveWith(addr tcpip.LinkAddress) { + r.RemoteLinkAddress = addr +} + // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 38994cca1..e75f58c64 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -34,6 +34,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -3498,6 +3499,52 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } } +func TestResolveWith(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + }) + ep := channel.New(0, defaultMTU, "") + ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + addr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address([]byte{192, 168, 1, 58}), + PrefixLen: 24, + }, + } + if err := s.AddProtocolAddress(nicID, addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) + + remoteAddr := tcpip.Address([]byte{192, 168, 1, 59}) + r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) + } + defer r.Release() + + // Should initially require resolution. + if !r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = false, want = true") + } + + // Manually resolving the route should no longer require resolution. + r.ResolveWith("\x01") + if r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = true, want = false") + } +} + // TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its // associated address is removed should not cause a panic. func TestRouteReleaseAfterAddrRemoval(t *testing.T) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c42bb0991..d77848d61 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -111,6 +111,7 @@ var ( ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"} ErrNotPermitted = &Error{msg: "operation not permitted"} ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} + ErrMalformedHeader = &Error{msg: "header is malformed"} ) var messageToError map[string]*Error @@ -159,6 +160,7 @@ func StringToError(s string) *Error { ErrBroadcastDisabled, ErrNotPermitted, ErrAddressFamilyNotSupported, + ErrMalformedHeader, } messageToError = make(map[string]*Error) diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index c67e0ba95..3ae6cc221 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -81,6 +81,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.ID = r.id ep.route = r.route.Clone() ep.dstPort = r.id.RemotePort + ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.route.NetProto} ep.RegisterNICID = r.route.NICID() ep.boundPortFlags = ep.portFlags |