summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD14
-rw-r--r--pkg/tcpip/checker/checker.go203
-rw-r--r--pkg/tcpip/header/checksum_test.go94
-rw-r--r--pkg/tcpip/header/icmpv4.go30
-rw-r--r--pkg/tcpip/header/icmpv6.go27
-rw-r--r--pkg/tcpip/header/ipv6.go53
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go336
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers_test.go356
-rw-r--r--pkg/tcpip/header/ipv6_fragment.go42
-rw-r--r--pkg/tcpip/header/ipv6_test.go44
-rw-r--r--pkg/tcpip/link/channel/channel.go14
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go9
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go7
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go6
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go8
-rw-r--r--pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go1
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go5
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go2
-rw-r--r--pkg/tcpip/link/tun/device.go4
-rw-r--r--pkg/tcpip/network/arp/arp_test.go9
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD2
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap.go77
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap_test.go126
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go4
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go127
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go137
-rw-r--r--pkg/tcpip/network/ip/BUILD1
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go321
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go620
-rw-r--r--pkg/tcpip/network/ip_test.go91
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go142
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go61
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go41
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go12
-rw-r--r--pkg/tcpip/network/ipv6/BUILD3
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go18
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go135
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go301
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go196
-rw-r--r--pkg/tcpip/network/ipv6/mld.go142
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go275
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go39
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go104
-rw-r--r--pkg/tcpip/network/multicast_group_test.go300
-rw-r--r--pkg/tcpip/socketops.go234
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go8
-rw-r--r--pkg/tcpip/stack/forwarding_test.go32
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go135
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go110
-rw-r--r--pkg/tcpip/stack/ndp_test.go38
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go95
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go491
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go137
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go457
-rw-r--r--pkg/tcpip/stack/nic.go113
-rw-r--r--pkg/tcpip/stack/nud.go21
-rw-r--r--pkg/tcpip/stack/nud_test.go53
-rw-r--r--pkg/tcpip/stack/pending_packets.go2
-rw-r--r--pkg/tcpip/stack/registration.go32
-rw-r--r--pkg/tcpip/stack/route.go172
-rw-r--r--pkg/tcpip/stack/stack.go31
-rw-r--r--pkg/tcpip/stack/stack_test.go188
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go15
-rw-r--r--pkg/tcpip/stack/transport_test.go9
-rw-r--r--pkg/tcpip/tcpip.go76
-rw-r--r--pkg/tcpip/tcpip_test.go40
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go20
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go30
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go32
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go79
-rw-r--r--pkg/tcpip/transport/tcp/accept.go2
-rw-r--r--pkg/tcpip/transport/tcp/connect.go21
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go134
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go12
-rw-r--r--pkg/tcpip/transport/tcp/snd.go35
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go42
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go99
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go10
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go148
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go123
82 files changed, 5061 insertions, 2757 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 27f96a3ac..89b765f1b 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,10 +1,24 @@
load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
+go_template_instance(
+ name = "sock_err_list",
+ out = "sock_err_list.go",
+ package = "tcpip",
+ prefix = "sockError",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*SockError",
+ "Linker": "*SockError",
+ },
+)
+
go_library(
name = "tcpip",
srcs = [
+ "sock_err_list.go",
"socketops.go",
"tcpip.go",
"time_unsafe.go",
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index d3ae56ac6..91971b687 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -117,6 +117,10 @@ func TTL(ttl uint8) NetworkChecker {
v = ip.TTL()
case header.IPv6:
v = ip.HopLimit()
+ case *ipv6HeaderWithExtHdr:
+ v = ip.HopLimit()
+ default:
+ t.Fatalf("unrecognized header type %T for TTL evaluation", ip)
}
if v != ttl {
t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
@@ -321,6 +325,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
}
}
+// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
+// field in ControlMessages.
+func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasOriginalDstAddress {
+ t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress)
+ } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" {
+ t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -1400,3 +1417,189 @@ func IGMPGroupAddress(want tcpip.Address) TransportChecker {
}
}
}
+
+// IPv6ExtHdrChecker is a function to check an extension header.
+type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader)
+
+// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers.
+func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) {
+ t.Helper()
+
+ ipv6 := header.IPv6(b)
+ if !ipv6.IsValid(len(b)) {
+ t.Error("not a valid IPv6 packet")
+ return
+ }
+
+ payloadIterator := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
+ buffer.View(ipv6.Payload()).ToVectorisedView(),
+ )
+
+ var rawPayloadHeader header.IPv6RawPayloadHeader
+ for {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done)
+ return
+ }
+ r, ok := h.(header.IPv6RawPayloadHeader)
+ if ok {
+ rawPayloadHeader = r
+ break
+ }
+ }
+
+ networkHeader := ipv6HeaderWithExtHdr{
+ IPv6: ipv6,
+ transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier),
+ payload: rawPayloadHeader.Buf.ToView(),
+ }
+
+ for _, checker := range checkers {
+ checker(t, []header.Network{&networkHeader})
+ }
+}
+
+// IPv6ExtHdr checks for the presence of extension headers.
+//
+// All the extension headers in headers will be checked exhaustively in the
+// order provided.
+func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr)
+ if !ok {
+ t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0])
+ return
+ }
+
+ payloadIterator := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()),
+ buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(),
+ )
+
+ for _, check := range headers {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done)
+ return
+ }
+ check(t, h)
+ }
+ // Validate we consumed all headers.
+ //
+ // The next one over should be a raw payload and then iterator should
+ // terminate.
+ wantDone := false
+ for {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done != wantDone {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone)
+ return
+ }
+ if done {
+ break
+ }
+ if _, ok := h.(header.IPv6RawPayloadHeader); !ok {
+ t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h)
+ continue
+ }
+ wantDone = true
+ }
+ }
+}
+
+var _ header.Network = (*ipv6HeaderWithExtHdr)(nil)
+
+// ipv6HeaderWithExtHdr provides a header.Network implementation that takes
+// extension headers into consideration, which is not the case with vanilla
+// header.IPv6.
+type ipv6HeaderWithExtHdr struct {
+ header.IPv6
+ transport tcpip.TransportProtocolNumber
+ payload []byte
+}
+
+// TransportProtocol implements header.Network.
+func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber {
+ return h.transport
+}
+
+// Payload implements header.Network.
+func (h *ipv6HeaderWithExtHdr) Payload() []byte {
+ return h.payload
+}
+
+// IPv6ExtHdrOptionChecker is a function to check an extension header option.
+type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption)
+
+// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop
+// extension header and validates the containing options with checkers.
+//
+// checkers must exhaustively contain all the expected options.
+func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker {
+ return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) {
+ t.Helper()
+
+ hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr)
+ if !ok {
+ t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader)
+ return
+ }
+ optionsIterator := hbh.Iter()
+ for _, f := range checkers {
+ opt, done, err := optionsIterator.Next()
+ if err != nil {
+ t.Errorf("optionsIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
+ }
+ f(t, opt)
+ }
+ // Validate all options were consumed.
+ for {
+ opt, done, err := optionsIterator.Next()
+ if err != nil {
+ t.Errorf("optionsIterator.Next(): %s", err)
+ return
+ }
+ if !done {
+ t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
+ }
+ if done {
+ break
+ }
+ }
+ }
+}
+
+// IPv6RouterAlert validates that an extension header option is the RouterAlert
+// option and matches on its value.
+func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
+ return func(t *testing.T, opt header.IPv6ExtHdrOption) {
+ routerAlert, ok := opt.(*header.IPv6RouterAlertOption)
+ if !ok {
+ t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt)
+ return
+ }
+ if routerAlert.Value != want {
+ t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index 309403482..5ab20ee86 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -19,6 +19,7 @@ package header_test
import (
"fmt"
"math/rand"
+ "sync"
"testing"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -169,3 +170,96 @@ func BenchmarkChecksum(b *testing.B) {
}
}
}
+
+func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) {
+ // icmpChecksum should not do any modifications of the header to
+ // calculate its checksum. Let's call it from a few go-routines and the
+ // race detector will trigger a warning if there are any concurrent
+ // read/write accesses.
+
+ const concurrency = 5
+ start := make(chan int)
+ ready := make(chan bool, concurrency)
+ var wg sync.WaitGroup
+ wg.Add(concurrency)
+ defer wg.Wait()
+
+ for i := 0; i < concurrency; i++ {
+ go func() {
+ defer wg.Done()
+
+ ready <- true
+ <-start
+
+ if got := headerChecksum(); want != got {
+ t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
+ }
+ if got := icmpChecksum(); want != got {
+ t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
+ }
+ }()
+ }
+ for i := 0; i < concurrency; i++ {
+ <-ready
+ }
+ close(start)
+}
+
+func TestICMPv4Checksum(t *testing.T) {
+ rnd := rand.New(rand.NewSource(42))
+
+ h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize))
+ if _, err := rnd.Read(h); err != nil {
+ t.Fatalf("rnd.Read failed: %v", err)
+ }
+ h.SetChecksum(0)
+
+ buf := make([]byte, 13)
+ if _, err := rnd.Read(buf); err != nil {
+ t.Fatalf("rnd.Read failed: %v", err)
+ }
+ vv := buffer.NewVectorisedView(len(buf), []buffer.View{
+ buffer.NewViewFromBytes(buf[:5]),
+ buffer.NewViewFromBytes(buf[5:]),
+ })
+
+ want := header.Checksum(vv.ToView(), 0)
+ want = ^header.Checksum(h, want)
+ h.SetChecksum(want)
+
+ testICMPChecksum(t, h.Checksum, func() uint16 {
+ return header.ICMPv4Checksum(h, vv)
+ }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
+}
+
+func TestICMPv6Checksum(t *testing.T) {
+ rnd := rand.New(rand.NewSource(42))
+
+ h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize))
+ if _, err := rnd.Read(h); err != nil {
+ t.Fatalf("rnd.Read failed: %v", err)
+ }
+ h.SetChecksum(0)
+
+ buf := make([]byte, 13)
+ if _, err := rnd.Read(buf); err != nil {
+ t.Fatalf("rnd.Read failed: %v", err)
+ }
+ vv := buffer.NewVectorisedView(len(buf), []buffer.View{
+ buffer.NewViewFromBytes(buf[:7]),
+ buffer.NewViewFromBytes(buf[7:10]),
+ buffer.NewViewFromBytes(buf[10:]),
+ })
+
+ dst := header.IPv6Loopback
+ src := header.IPv6Loopback
+
+ want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size()))
+ want = header.Checksum(vv.ToView(), want)
+ want = ^header.Checksum(h, want)
+ h.SetChecksum(want)
+
+ testICMPChecksum(t, h.Checksum, func() uint16 {
+ return header.ICMPv6Checksum(h, src, dst, vv)
+ }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
+}
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 2f13dea6a..5f9b8e9e2 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -16,6 +16,7 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -199,17 +200,24 @@ func (b ICMPv4) SetSequence(sequence uint16) {
// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header,
// and payload.
func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 {
- // Calculate the IPv6 pseudo-header upper-layer checksum.
- xsum := uint16(0)
- for _, v := range vv.Views() {
- xsum = Checksum(v, xsum)
- }
+ xsum := ChecksumVV(vv, 0)
+
+ // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
+ xsum = Checksum(h[:2], xsum)
+ xsum = Checksum(h[4:], xsum)
- // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
- h2, h3 := h[2], h[3]
- h[2], h[3] = 0, 0
- xsum = ^Checksum(h, xsum)
- h[2], h[3] = h2, h3
+ return ^xsum
+}
- return xsum
+// ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when
+// a packet having a `net` header causing an ICMP error.
+func ICMPOriginFromNetProto(net tcpip.NetworkProtocolNumber) tcpip.SockErrOrigin {
+ switch net {
+ case IPv4ProtocolNumber:
+ return tcpip.SockExtErrorOriginICMP
+ case IPv6ProtocolNumber:
+ return tcpip.SockExtErrorOriginICMP6
+ default:
+ panic(fmt.Sprintf("unsupported net proto to extract ICMP error origin: %d", net))
+ }
}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 2eef64b4d..eca9750ab 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -265,22 +265,13 @@ func (b ICMPv6) Payload() []byte {
// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header,
// IPv6 src/dst addresses and the payload.
func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
- // Calculate the IPv6 pseudo-header upper-layer checksum.
- xsum := Checksum([]byte(src), 0)
- xsum = Checksum([]byte(dst), xsum)
- var upperLayerLength [4]byte
- binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
- xsum = Checksum(upperLayerLength[:], xsum)
- xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum)
- for _, v := range vv.Views() {
- xsum = Checksum(v, xsum)
- }
-
- // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
- h2, h3 := h[2], h[3]
- h[2], h[3] = 0, 0
- xsum = ^Checksum(h, xsum)
- h[2], h[3] = h2, h3
-
- return xsum
+ xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size()))
+
+ xsum = ChecksumVV(vv, xsum)
+
+ // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
+ xsum = Checksum(h[:2], xsum)
+ xsum = Checksum(h[4:], xsum)
+
+ return ^xsum
}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 55d09355a..5580d6a78 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -18,7 +18,6 @@ import (
"crypto/sha256"
"encoding/binary"
"fmt"
- "strings"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -48,11 +47,13 @@ type IPv6Fields struct {
// FlowLabel is the "flow label" field of an IPv6 packet.
FlowLabel uint32
- // PayloadLength is the "payload length" field of an IPv6 packet.
+ // PayloadLength is the "payload length" field of an IPv6 packet, including
+ // the length of all extension headers.
PayloadLength uint16
- // NextHeader is the "next header" field of an IPv6 packet.
- NextHeader uint8
+ // TransportProtocol is the transport layer protocol number. Serialized in the
+ // last "next header" field of the IPv6 header + extension headers.
+ TransportProtocol tcpip.TransportProtocolNumber
// HopLimit is the "Hop Limit" field of an IPv6 packet.
HopLimit uint8
@@ -62,6 +63,9 @@ type IPv6Fields struct {
// DstAddr is the "destination ip address" of an IPv6 packet.
DstAddr tcpip.Address
+
+ // ExtensionHeaders are the extension headers following the IPv6 header.
+ ExtensionHeaders IPv6ExtHdrSerializer
}
// IPv6 represents an ipv6 header stored in a byte array.
@@ -148,13 +152,17 @@ const (
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to
// be contained within this subnet.
-var IPv6EmptySubnet = func() tcpip.Subnet {
- subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any))
- if err != nil {
- panic(err)
- }
- return subnet
-}()
+var IPv6EmptySubnet = tcpip.AddressWithPrefix{
+ Address: IPv6Any,
+ PrefixLen: 0,
+}.Subnet()
+
+// IPv4MappedIPv6Subnet is the prefix for an IPv4 mapped IPv6 address as defined
+// by RFC 4291 section 2.5.5.
+var IPv4MappedIPv6Subnet = tcpip.AddressWithPrefix{
+ Address: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00",
+ PrefixLen: 96,
+}.Subnet()
// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined
// by RFC 4291 section 2.5.6.
@@ -253,12 +261,14 @@ func (IPv6) SetChecksum(uint16) {
// Encode encodes all the fields of the ipv6 header.
func (b IPv6) Encode(i *IPv6Fields) {
+ extHdr := b[IPv6MinimumSize:]
b.SetTOS(i.TrafficClass, i.FlowLabel)
b.SetPayloadLength(i.PayloadLength)
- b[IPv6NextHeaderOffset] = i.NextHeader
b[hopLimit] = i.HopLimit
b.SetSourceAddress(i.SrcAddr)
b.SetDestinationAddress(i.DstAddr)
+ nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr)
+ b[IPv6NextHeaderOffset] = nextHeader
}
// IsValid performs basic validation on the packet.
@@ -286,7 +296,7 @@ func IsV4MappedAddress(addr tcpip.Address) bool {
return false
}
- return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff")
+ return IPv4MappedIPv6Subnet.Contains(addr)
}
// IsV6MulticastAddress determines if the provided address is an IPv6
@@ -392,17 +402,6 @@ func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope
}
-// IsV6UniqueLocalAddress determines if the provided address is an IPv6
-// unique-local address (within the prefix FC00::/7).
-func IsV6UniqueLocalAddress(addr tcpip.Address) bool {
- if len(addr) != IPv6AddressSize {
- return false
- }
- // According to RFC 4193 section 3.1, a unique local address has the prefix
- // FC00::/7.
- return (addr[0] & 0xfe) == 0xfc
-}
-
// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
//
@@ -449,9 +448,6 @@ const (
// LinkLocalScope indicates a link-local address.
LinkLocalScope IPv6AddressScope = iota
- // UniqueLocalScope indicates a unique-local address.
- UniqueLocalScope
-
// GlobalScope indicates a global address.
GlobalScope
)
@@ -469,9 +465,6 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
case IsV6LinkLocalAddress(addr):
return LinkLocalScope, nil
- case IsV6UniqueLocalAddress(addr):
- return UniqueLocalScope, nil
-
default:
return GlobalScope, nil
}
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 571eae233..f18981332 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -18,9 +18,12 @@ import (
"bufio"
"bytes"
"encoding/binary"
+ "errors"
"fmt"
"io"
+ "math"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -75,8 +78,8 @@ const (
// Fragment Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetOffset = 0
- // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to
- // discard from the Fragment Offset.
+ // ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment
+ // Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetShift = 3
// ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an
@@ -114,6 +117,37 @@ const (
IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8
)
+// padIPv6OptionsLength returns the total length for IPv6 options of length l
+// considering the 8-octet alignment as stated in RFC 8200 Section 4.2.
+func padIPv6OptionsLength(length int) int {
+ return (length + ipv6ExtHdrLenBytesPerUnit - 1) & ^(ipv6ExtHdrLenBytesPerUnit - 1)
+}
+
+// padIPv6Option fills b with the appropriate padding options depending on its
+// length.
+func padIPv6Option(b []byte) {
+ switch len(b) {
+ case 0: // No padding needed.
+ case 1: // Pad with Pad1.
+ b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier)
+ default: // Pad with PadN.
+ s := b[ipv6ExtHdrOptionPayloadOffset:]
+ for i := range s {
+ s[i] = 0
+ }
+ b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier)
+ b[ipv6ExtHdrOptionLengthOffset] = uint8(len(s))
+ }
+}
+
+// ipv6OptionsAlignmentPadding returns the number of padding bytes needed to
+// serialize an option at headerOffset with alignment requirements
+// [align]n + alignOffset.
+func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int {
+ padLen := headerOffset - alignOffset
+ return ((padLen + align - 1) & ^(align - 1)) - padLen
+}
+
// IPv6PayloadHeader is implemented by the various headers that can be found
// in an IPv6 payload.
//
@@ -206,29 +240,55 @@ type IPv6ExtHdrOption interface {
isIPv6ExtHdrOption()
}
-// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier.
-type IPv6ExtHdrOptionIndentifier uint8
+// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier.
+type IPv6ExtHdrOptionIdentifier uint8
const (
// ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that
// provides 1 byte padding, as outlined in RFC 8200 section 4.2.
- ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0
+ ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0
// ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that
// provides variable length byte padding, as outlined in RFC 8200 section 4.2.
- ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1
+ ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1
+
+ // ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router
+ // Alert Hop by Hop option as defined in RFC 2711 section 2.1.
+ ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5
+
+ // ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header
+ // option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionTypeOffset = 0
+
+ // ipv6ExtHdrOptionLengthOffset is the option length offset in an extension
+ // header option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionLengthOffset = 1
+
+ // ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension
+ // header option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionPayloadOffset = 2
)
+// ipv6UnknownActionFromIdentifier maps an extension header option's
+// identifier's high bits to the action to take when the identifier is unknown.
+func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction {
+ return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+}
+
+// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option
+// is malformed.
+var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
+
// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
// header option that is unknown by the parsing utilities.
type IPv6UnknownExtHdrOption struct {
- Identifier IPv6ExtHdrOptionIndentifier
+ Identifier IPv6ExtHdrOptionIdentifier
Data []byte
}
// UnknownAction implements IPv6OptionUnknownAction.UnknownAction.
func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
- return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+ return ipv6UnknownActionFromIdentifier(o.Identifier)
}
// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
@@ -251,7 +311,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
// options buffer has been exhausted and we are done iterating.
return nil, true, nil
}
- id := IPv6ExtHdrOptionIndentifier(temp)
+ id := IPv6ExtHdrOptionIdentifier(temp)
// If the option identifier indicates the option is a Pad1 option, then we
// know the option does not have Length and Data fields. End processing of
@@ -294,6 +354,19 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
}
continue
+ case ipv6RouterAlertHopByHopOptionIdentifier:
+ var routerAlertValue [ipv6RouterAlertPayloadLength]byte
+ if n, err := io.ReadFull(&i.reader, routerAlertValue[:]); err != nil {
+ switch err {
+ case io.EOF, io.ErrUnexpectedEOF:
+ return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
+ default:
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err)
+ }
+ } else if n != int(length) {
+ return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
+ }
+ return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil
default:
bytes := make([]byte, length)
if n, err := io.ReadFull(&i.reader, bytes); err != nil {
@@ -609,3 +682,248 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil
}
+
+// IPv6SerializableExtHdr provides serialization for IPv6 extension
+// headers.
+type IPv6SerializableExtHdr interface {
+ // identifier returns the assigned IPv6 header identifier for this extension
+ // header.
+ identifier() IPv6ExtensionHeaderIdentifier
+
+ // length returns the total serialized length in bytes of this extension
+ // header, including the common next header and length fields.
+ length() int
+
+ // serializeInto serializes the receiver into the provided byte
+ // buffer and with the provided nextHeader value.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto returns the number of bytes that was used to serialize the
+ // receiver. Implementers must only use the number of bytes required to
+ // serialize the receiver. Callers MAY provide a larger buffer than required
+ // to serialize into.
+ serializeInto(nextHeader uint8, b []byte) int
+}
+
+var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil)
+
+// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop
+// options extension header.
+type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption
+
+const (
+ // ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field
+ // in a hop by hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrNextHeaderOffset = 0
+
+ // ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop
+ // by hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrLengthOffset = 1
+
+ // ipv6HopByHopExtHdrPayloadOffset is the offset of the options in a hop by
+ // hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrOptionsOffset = 2
+
+ // ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet
+ // words in a hop by hop extension header's length field, as stated in RFC
+ // 8200 section 4.3:
+ // Length of the Hop-by-Hop Options header in 8-octet units,
+ // not including the first 8 octets.
+ ipv6HopByHopExtHdrUnaccountedLenWords = 1
+)
+
+// identifier implements IPv6SerializableExtHdr.
+func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
+ return IPv6HopByHopOptionsExtHdrIdentifier
+}
+
+// length implements IPv6SerializableExtHdr.
+func (h IPv6SerializableHopByHopExtHdr) length() int {
+ var total int
+ for _, opt := range h {
+ align, alignOffset := opt.alignment()
+ total += ipv6OptionsAlignmentPadding(total, align, alignOffset)
+ total += ipv6ExtHdrOptionPayloadOffset + int(opt.length())
+ }
+ // Account for next header and total length fields and add padding.
+ return padIPv6OptionsLength(ipv6HopByHopExtHdrOptionsOffset + total)
+}
+
+// serializeInto implements IPv6SerializableExtHdr.
+func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int {
+ optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:]
+ totalLength := ipv6HopByHopExtHdrOptionsOffset
+ for _, opt := range h {
+ // Calculate alignment requirements and pad buffer if necessary.
+ align, alignOffset := opt.alignment()
+ padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset)
+ if padLen != 0 {
+ padIPv6Option(optBuffer[:padLen])
+ totalLength += padLen
+ optBuffer = optBuffer[padLen:]
+ }
+
+ l := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:])
+ optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier())
+ optBuffer[ipv6ExtHdrOptionLengthOffset] = l
+ l += ipv6ExtHdrOptionPayloadOffset
+ totalLength += int(l)
+ optBuffer = optBuffer[l:]
+ }
+ padded := padIPv6OptionsLength(totalLength)
+ if padded != totalLength {
+ padIPv6Option(optBuffer[:padded-totalLength])
+ totalLength = padded
+ }
+ wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords
+ if wordsLen > math.MaxUint8 {
+ panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen))
+ }
+ b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader
+ b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen)
+ return totalLength
+}
+
+// IPv6SerializableHopByHopOption provides serialization for hop by hop options.
+type IPv6SerializableHopByHopOption interface {
+ // identifier returns the option identifier of this Hop by Hop option.
+ identifier() IPv6ExtHdrOptionIdentifier
+
+ // length returns the *payload* size of the option (not considering the type
+ // and length fields).
+ length() uint8
+
+ // alignment returns the alignment requirements from this option.
+ //
+ // Alignment requirements take the form [align]n + offset as specified in
+ // RFC 8200 section 4.2. The alignment requirement is on the offset between
+ // the option type byte and the start of the hop by hop header.
+ //
+ // align must be a power of 2.
+ alignment() (align int, offset int)
+
+ // serializeInto serializes the receiver into the provided byte
+ // buffer.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto will return the number of bytes that was used to
+ // serialize the receiver. Implementers must only use the number of
+ // bytes required to serialize the receiver. Callers MAY provide a
+ // larger buffer than required to serialize into.
+ serializeInto([]byte) uint8
+}
+
+var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil)
+
+// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in
+// RFC 2711 section 2.1.
+type IPv6RouterAlertOption struct {
+ Value IPv6RouterAlertValue
+}
+
+// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option.
+type IPv6RouterAlertValue uint16
+
+const (
+ // IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener
+ // Discovery message as defined in RFC 2711 section 2.1.
+ IPv6RouterAlertMLD IPv6RouterAlertValue = 0
+ // IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as
+ // defined in RFC 2711 section 2.1.
+ IPv6RouterAlertRSVP IPv6RouterAlertValue = 1
+ // IPv6RouterAlertActiveNetworks indicates a datagram containing an Active
+ // Networks message as defined in RFC 2711 section 2.1.
+ IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2
+
+ // ipv6RouterAlertPayloadLength is the length of the Router Alert payload
+ // as defined in RFC 2711.
+ ipv6RouterAlertPayloadLength = 2
+
+ // ipv6RouterAlertAlignmentRequirement is the alignment requirement for the
+ // Router Alert option defined as 2n+0 in RFC 2711.
+ ipv6RouterAlertAlignmentRequirement = 2
+
+ // ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset
+ // requirement for the Router Alert option defined as 2n+0 in RFC 2711 section
+ // 2.1.
+ ipv6RouterAlertAlignmentOffsetRequirement = 0
+)
+
+// UnknownAction implements IPv6ExtHdrOption.
+func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction {
+ return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier)
+}
+
+// isIPv6ExtHdrOption implements IPv6ExtHdrOption.
+func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {}
+
+// identifier implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier {
+ return ipv6RouterAlertHopByHopOptionIdentifier
+}
+
+// length implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) length() uint8 {
+ return ipv6RouterAlertPayloadLength
+}
+
+// alignment implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) alignment() (int, int) {
+ // From RFC 2711 section 2.1:
+ // Alignment requirement: 2n+0.
+ return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement
+}
+
+// serializeInto implements IPv6SerializableHopByHopOption.
+func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 {
+ binary.BigEndian.PutUint16(b, uint16(o.Value))
+ return ipv6RouterAlertPayloadLength
+}
+
+// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers.
+type IPv6ExtHdrSerializer []IPv6SerializableExtHdr
+
+// Serialize serializes the provided list of IPv6 extension headers into b.
+//
+// Note, b must be of sufficient size to hold all the headers in s. See
+// IPv6ExtHdrSerializer.Length for details on the getting the total size of a
+// serialized IPv6ExtHdrSerializer.
+//
+// Serialize may panic if b is not of sufficient size to hold all the options
+// in s.
+//
+// Serialize takes the transportProtocol value to be used as the last extension
+// header's Next Header value and returns the header identifier of the first
+// serialized extension header and the total serialized length.
+func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) {
+ nextHeader := uint8(transportProtocol)
+ if len(s) == 0 {
+ return nextHeader, 0
+ }
+ var totalLength int
+ for i, h := range s[:len(s)-1] {
+ length := h.serializeInto(uint8(s[i+1].identifier()), b)
+ b = b[length:]
+ totalLength += length
+ }
+ totalLength += s[len(s)-1].serializeInto(nextHeader, b)
+ return uint8(s[0].identifier()), totalLength
+}
+
+// Length returns the total number of bytes required to serialize the extension
+// headers.
+func (s IPv6ExtHdrSerializer) Length() int {
+ var totalLength int
+ for _, h := range s {
+ totalLength += h.length()
+ }
+ return totalLength
+}
diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go
index ab20c5f37..65adc6250 100644
--- a/pkg/tcpip/header/ipv6_extension_headers_test.go
+++ b/pkg/tcpip/header/ipv6_extension_headers_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -59,7 +60,7 @@ func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool
func TestIPv6UnknownExtHdrOption(t *testing.T) {
tests := []struct {
name string
- identifier IPv6ExtHdrOptionIndentifier
+ identifier IPv6ExtHdrOptionIdentifier
expectedUnknownAction IPv6OptionUnknownAction
}{
{
@@ -211,6 +212,31 @@ func TestIPv6OptionsExtHdrIterErr(t *testing.T) {
bytes: []byte{1, 3},
err: io.ErrUnexpectedEOF,
},
+ {
+ name: "Router alert without data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 0},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with partial data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with partial data and Pad1",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1, 0},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with extra data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 3, 1, 2, 3},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with missing data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1},
+ err: io.ErrUnexpectedEOF,
+ },
}
check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) {
@@ -990,3 +1016,331 @@ func TestIPv6ExtHdrIter(t *testing.T) {
})
}
}
+
+var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil)
+
+// dummyHbHOptionSerializer provides a generic implementation of
+// IPv6SerializableHopByHopOption for use in tests.
+type dummyHbHOptionSerializer struct {
+ id IPv6ExtHdrOptionIdentifier
+ payload []byte
+ align int
+ alignOffset int
+}
+
+// identifier implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier {
+ return s.id
+}
+
+// length implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) length() uint8 {
+ return uint8(len(s.payload))
+}
+
+// alignment implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) alignment() (int, int) {
+ align := 1
+ if s.align != 0 {
+ align = s.align
+ }
+ return align, s.alignOffset
+}
+
+// serializeInto implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 {
+ return uint8(copy(b, s.payload))
+}
+
+func TestIPv6HopByHopSerializer(t *testing.T) {
+ validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) {
+ t.Helper()
+ dummy, ok := serializable.(*dummyHbHOptionSerializer)
+ if !ok {
+ t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable)
+ }
+ unknown, ok := deserialized.(*IPv6UnknownExtHdrOption)
+ if !ok {
+ t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{})
+ }
+ if dummy.id != unknown.Identifier {
+ t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id)
+ }
+ if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" {
+ t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff)
+ }
+ }
+ tests := []struct {
+ name string
+ nextHeader uint8
+ options []IPv6SerializableHopByHopOption
+ expect []byte
+ validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption)
+ }{
+ {
+ name: "single option",
+ nextHeader: 13,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 15,
+ payload: []byte{9, 8, 7, 6},
+ },
+ },
+ expect: []byte{13, 0, 15, 4, 9, 8, 7, 6},
+ validate: validateDummies,
+ },
+ {
+ name: "short option padN zero",
+ nextHeader: 88,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5},
+ },
+ },
+ expect: []byte{88, 0, 22, 2, 4, 5, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "short option pad1",
+ nextHeader: 11,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 33,
+ payload: []byte{1, 2, 3},
+ },
+ },
+ expect: []byte{11, 0, 33, 3, 1, 2, 3, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "long option padN",
+ nextHeader: 55,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 77,
+ payload: []byte{1, 2, 3, 4, 5, 6, 7, 8},
+ },
+ },
+ expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2, 3},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ },
+ },
+ expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options align 2n",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2, 3},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ align: 2,
+ },
+ },
+ expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options align 8n+1",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ align: 8,
+ alignOffset: 1,
+ },
+ },
+ expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "no options",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{},
+ expect: []byte{33, 0, 1, 4, 0, 0, 0, 0},
+ },
+ {
+ name: "Router Alert",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}},
+ expect: []byte{33, 0, 5, 2, 0, 0, 1, 0},
+ validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) {
+ t.Helper()
+ routerAlert, ok := deserialized.(*IPv6RouterAlertOption)
+ if !ok {
+ t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized)
+ }
+ if routerAlert.Value != IPv6RouterAlertMLD {
+ t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD)
+ }
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := IPv6SerializableHopByHopExtHdr(test.options)
+ length := s.length()
+ if length != len(test.expect) {
+ t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect))
+ }
+ b := make([]byte, length)
+ for i := range b {
+ // Fill the buffer with ones to ensure all padding is correctly set.
+ b[i] = 0xFF
+ }
+ if got := s.serializeInto(test.nextHeader, b); got != length {
+ t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length)
+ }
+ if diff := cmp.Diff(test.expect, b); diff != "" {
+ t.Fatalf("serialization mismatch (-want +got):\n%s", diff)
+ }
+
+ // Deserialize the options and verify them.
+ optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit
+ iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter()
+ for _, testOpt := range test.options {
+ opt, done, err := iter.Next()
+ if err != nil {
+ t.Fatalf("iter.Next(): %s", err)
+ }
+ if done {
+ t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
+ }
+ test.validate(t, testOpt, opt)
+ }
+ opt, done, err := iter.Next()
+ if err != nil {
+ t.Fatalf("iter.Next(): %s", err)
+ }
+ if !done {
+ t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
+ }
+ })
+ }
+}
+
+var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil)
+
+// dummyIPv6ExtHdrSerializer provides a generic implementation of
+// IPv6SerializableExtHdr for use in tests.
+//
+// The dummy header always carries the nextHeader value in the first byte.
+type dummyIPv6ExtHdrSerializer struct {
+ id IPv6ExtensionHeaderIdentifier
+ headerContents []byte
+}
+
+// identifier implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier {
+ return s.id
+}
+
+// length implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) length() int {
+ return len(s.headerContents) + 1
+}
+
+// serializeInto implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int {
+ b[0] = nextHeader
+ return copy(b[1:], s.headerContents) + 1
+}
+
+func TestIPv6ExtHdrSerializer(t *testing.T) {
+ tests := []struct {
+ name string
+ headers []IPv6SerializableExtHdr
+ nextHeader tcpip.TransportProtocolNumber
+ expectSerialized []byte
+ expectNextHeader uint8
+ }{
+ {
+ name: "one header",
+ headers: []IPv6SerializableExtHdr{
+ &dummyIPv6ExtHdrSerializer{
+ id: 15,
+ headerContents: []byte{1, 2, 3, 4},
+ },
+ },
+ nextHeader: TCPProtocolNumber,
+ expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4},
+ expectNextHeader: 15,
+ },
+ {
+ name: "two headers",
+ headers: []IPv6SerializableExtHdr{
+ &dummyIPv6ExtHdrSerializer{
+ id: 22,
+ headerContents: []byte{1, 2, 3},
+ },
+ &dummyIPv6ExtHdrSerializer{
+ id: 23,
+ headerContents: []byte{4, 5, 6},
+ },
+ },
+ nextHeader: ICMPv6ProtocolNumber,
+ expectSerialized: []byte{
+ 23, 1, 2, 3,
+ byte(ICMPv6ProtocolNumber), 4, 5, 6,
+ },
+ expectNextHeader: 22,
+ },
+ {
+ name: "no headers",
+ headers: []IPv6SerializableExtHdr{},
+ nextHeader: UDPProtocolNumber,
+ expectSerialized: []byte{},
+ expectNextHeader: byte(UDPProtocolNumber),
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := IPv6ExtHdrSerializer(test.headers)
+ l := s.Length()
+ if got, want := l, len(test.expectSerialized); got != want {
+ t.Fatalf("got serialized length = %d, want = %d", got, want)
+ }
+ b := make([]byte, l)
+ for i := range b {
+ // Fill the buffer with garbage to make sure we're writing to all bytes.
+ b[i] = 0xFF
+ }
+ nextHeader, serializedLen := s.Serialize(test.nextHeader, b)
+ if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader {
+ t.Errorf(
+ "got s.Serialize(..) = (%d, %d), want = (%d, %d)",
+ nextHeader,
+ serializedLen,
+ test.expectNextHeader,
+ len(test.expectSerialized),
+ )
+ }
+ if diff := cmp.Diff(test.expectSerialized, b); diff != "" {
+ t.Errorf("serialization mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go
index 018555a26..9d09f32eb 100644
--- a/pkg/tcpip/header/ipv6_fragment.go
+++ b/pkg/tcpip/header/ipv6_fragment.go
@@ -27,12 +27,11 @@ const (
idV6 = 4
)
-// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the
-// fields of a packet that needs to be encoded.
-type IPv6FragmentFields struct {
- // NextHeader is the "next header" field of an IPv6 fragment.
- NextHeader uint8
+var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil)
+// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment
+// extension header as defined in RFC 8200 section 4.5.
+type IPv6SerializableFragmentExtHdr struct {
// FragmentOffset is the "fragment offset" field of an IPv6 fragment.
FragmentOffset uint16
@@ -43,6 +42,29 @@ type IPv6FragmentFields struct {
Identification uint32
}
+// identifier implements IPv6SerializableFragmentExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
+ return IPv6FragmentHeader
+}
+
+// length implements IPv6SerializableFragmentExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) length() int {
+ return IPv6FragmentHeaderSize
+}
+
+// serializeInto implements IPv6SerializableFragmentExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int {
+ // Prevent too many bounds checks.
+ _ = b[IPv6FragmentHeaderSize:]
+ binary.BigEndian.PutUint32(b[idV6:], h.Identification)
+ binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift)
+ b[nextHdrFrag] = nextHeader
+ if h.M {
+ b[more] |= ipv6FragmentExtHdrMFlagMask
+ }
+ return IPv6FragmentHeaderSize
+}
+
// IPv6Fragment represents an ipv6 fragment header stored in a byte array.
// Most of the methods of IPv6Fragment access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
@@ -58,16 +80,6 @@ const (
IPv6FragmentHeaderSize = 8
)
-// Encode encodes all the fields of the ipv6 fragment.
-func (b IPv6Fragment) Encode(i *IPv6FragmentFields) {
- b[nextHdrFrag] = i.NextHeader
- binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3)
- if i.M {
- b[more] |= 1
- }
- binary.BigEndian.PutUint32(b[idV6:], i.Identification)
-}
-
// IsValid performs basic validation on the fragment header.
func (b IPv6Fragment) IsValid() bool {
return len(b) >= IPv6FragmentHeaderSize
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index 426a873b1..e3fbd64f3 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -215,48 +215,6 @@ func TestLinkLocalAddrWithOpaqueIID(t *testing.T) {
}
}
-func TestIsV6UniqueLocalAddress(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- expected bool
- }{
- {
- name: "Valid Unique 1",
- addr: uniqueLocalAddr1,
- expected: true,
- },
- {
- name: "Valid Unique 2",
- addr: uniqueLocalAddr1,
- expected: true,
- },
- {
- name: "Link Local",
- addr: linkLocalAddr,
- expected: false,
- },
- {
- name: "Global",
- addr: globalAddr,
- expected: false,
- },
- {
- name: "IPv4",
- addr: "\x01\x02\x03\x04",
- expected: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected {
- t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
- }
- })
- }
-}
-
func TestIsV6LinkLocalMulticastAddress(t *testing.T) {
tests := []struct {
name string
@@ -346,7 +304,7 @@ func TestScopeForIPv6Address(t *testing.T) {
{
name: "Unique Local",
addr: uniqueLocalAddr1,
- scope: header.UniqueLocalScope,
+ scope: header.GlobalScope,
err: nil,
},
{
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 0efbfb22b..d9f8e3b35 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -31,7 +31,7 @@ type PacketInfo struct {
Pkt *stack.PacketBuffer
Proto tcpip.NetworkProtocolNumber
GSO *stack.GSO
- Route *stack.Route
+ Route stack.RouteInfo
}
// Notification is the interface for receiving notification from the packet
@@ -230,15 +230,11 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket stores outbound packets into the channel.
func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- // Clone r then release its resource so we only get the relevant fields from
- // stack.Route without holding a reference to a NIC's endpoint.
- route := r.Clone()
- route.Release()
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
GSO: gso,
- Route: route,
+ Route: r.GetFields(),
}
e.q.Write(p)
@@ -248,17 +244,13 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// WritePackets stores outbound packets into the channel.
func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- // Clone r then release its resource so we only get the relevant fields from
- // stack.Route without holding a reference to a NIC's endpoint.
- route := r.Clone()
- route.Release()
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
GSO: gso,
- Route: route,
+ Route: r.GetFields(),
}
if !e.q.Write(p) {
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 9f2084eae..cb94cbea6 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -284,9 +284,12 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher
}
switch sa.(type) {
case *unix.SockaddrLinklayer:
- // enable PACKET_FANOUT mode is the underlying socket is
- // of type AF_PACKET.
- const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG
+ // Enable PACKET_FANOUT mode if the underlying socket is of type
+ // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will
+ // prevent gvisor from receiving fragmented packets and the host does the
+ // reassembly on our behalf before delivering the fragments. This makes it
+ // hard to test fragmentation reassembly code in Netstack.
+ const fanoutType = unix.PACKET_FANOUT_HASH
fanoutArg := fanoutID | fanoutType<<16
if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index ce4da7230..a87abc6d6 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -323,9 +323,8 @@ func TestPreserveSrcAddress(t *testing.T) {
defer c.cleanup()
// Set LocalLinkAddress in route to the value of the bridged address.
- r := &stack.Route{
- LocalLinkAddress: baddr,
- }
+ var r stack.Route
+ r.LocalLinkAddress = baddr
r.ResolveWith(raddr)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -335,7 +334,7 @@ func TestPreserveSrcAddress(t *testing.T) {
ReserveHeaderBytes: header.EthernetMinimumSize,
Data: buffer.VectorisedView{},
})
- if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 3e4afcdad..b511d3a31 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -51,7 +51,8 @@ func TestInjectableEndpointDispatch(t *testing.T) {
Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
})
pkt.TransportHeader().Push(1)[0] = 0xFA
- packetRoute := stack.Route{RemoteAddress: dstIP}
+ var packetRoute stack.Route
+ packetRoute.RemoteAddress = dstIP
endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt)
@@ -73,7 +74,8 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
Data: buffer.NewView(0).ToVectorisedView(),
})
pkt.TransportHeader().Push(1)[0] = 0xFA
- packetRoute := stack.Route{RemoteAddress: dstIP}
+ var packetRoute stack.Route
+ packetRoute.RemoteAddress = dstIP
endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt)
buf := make([]byte, 6500)
bytesRead, err := sock.Read(buf)
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index 27667f5f0..b7458b620 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -154,8 +154,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
// WritePacket caller's do not set the following fields in PacketBuffer
// so we populate them here.
- newRoute := r.Clone()
- pkt.EgressRoute = newRoute
+ pkt.EgressRoute = r
pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = protocol
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
@@ -178,11 +177,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
for pkt := pkts.Front(); pkt != nil; {
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
nxt := pkt.Next()
- // Since qdisc can hold onto a packet for long we should Clone
- // the route here to ensure it doesn't get released while the
- // packet is still in our queue.
- newRoute := pkt.EgressRoute.Clone()
- pkt.EgressRoute = newRoute
if !d.q.enqueue(pkt) {
if enqueued > 0 {
d.newPacketWaker.Assert()
diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
index eb5abb906..45adcbccb 100644
--- a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
+++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
@@ -61,6 +61,7 @@ func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool {
q.mu.Lock()
r := q.used < q.limit
if r {
+ s.EgressRoute.Acquire()
q.list.PushBack(s)
q.used++
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 7131392cc..dd2e1a125 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -340,9 +340,8 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6))
// Set both remote and local link address in route.
- r := stack.Route{
- LocalLinkAddress: newLocalLinkAddress,
- }
+ var r stack.Route
+ r.LocalLinkAddress = newLocalLinkAddress
r.ResolveWith(remoteLinkAddr)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 8d9a91020..1a2cc39eb 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -263,7 +263,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe
fragmentOffset = fragOffset
case header.ARPProtocolNumber:
- if parse.ARP(pkt) {
+ if !parse.ARP(pkt) {
return
}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index a364c5801..bfac358f4 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -264,7 +264,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
// If the packet does not already have link layer header, and the route
// does not exist, we can't compute it. This is possibly a raw packet, tun
// device doesn't support this at the moment.
- if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress() == "" {
+ if info.Pkt.LinkHeader().View().IsEmpty() && len(info.Route.RemoteLinkAddress) == 0 {
return nil, false
}
@@ -272,7 +272,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
if d.hasFlags(linux.IFF_TAP) {
// Add ethernet header if not provided.
if info.Pkt.LinkHeader().View().IsEmpty() {
- d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress(), info.Proto, info.Pkt)
+ d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
}
vv.AppendView(info.Pkt.LinkHeader().View())
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 0fb373612..a25cba513 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -441,9 +441,8 @@ func (*testInterface) Promiscuous() bool {
}
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- r := stack.Route{
- NetProto: protocol,
- }
+ var r stack.Route
+ r.NetProto = protocol
r.ResolveWith(remoteLinkAddr)
return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
}
@@ -557,8 +556,8 @@ func TestLinkAddressRequest(t *testing.T) {
t.Fatal("expected to send a link address request")
}
- if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr)
+ if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
}
rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index d8e4a3b54..429af69ee 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -18,7 +18,6 @@ go_template_instance(
go_library(
name = "fragmentation",
srcs = [
- "frag_heap.go",
"fragmentation.go",
"reassembler.go",
"reassembler_list.go",
@@ -38,7 +37,6 @@ go_test(
name = "fragmentation_test",
size = "small",
srcs = [
- "frag_heap_test.go",
"fragmentation_test.go",
"reassembler_test.go",
],
diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go
deleted file mode 100644
index 0b570d25a..000000000
--- a/pkg/tcpip/network/fragmentation/frag_heap.go
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "container/heap"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-type fragment struct {
- offset uint16
- vv buffer.VectorisedView
-}
-
-type fragHeap []fragment
-
-func (h *fragHeap) Len() int {
- return len(*h)
-}
-
-func (h *fragHeap) Less(i, j int) bool {
- return (*h)[i].offset < (*h)[j].offset
-}
-
-func (h *fragHeap) Swap(i, j int) {
- (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
-}
-
-func (h *fragHeap) Push(x interface{}) {
- *h = append(*h, x.(fragment))
-}
-
-func (h *fragHeap) Pop() interface{} {
- old := *h
- n := len(old)
- x := old[n-1]
- *h = old[:n-1]
- return x
-}
-
-// reassamble empties the heap and returns a VectorisedView
-// containing a reassambled version of the fragments inside the heap.
-func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
- curr := heap.Pop(h).(fragment)
- views := curr.vv.Views()
- size := curr.vv.Size()
-
- if curr.offset != 0 {
- return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
- }
-
- for h.Len() > 0 {
- curr := heap.Pop(h).(fragment)
- if int(curr.offset) < size {
- curr.vv.TrimFront(size - int(curr.offset))
- } else if int(curr.offset) > size {
- return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
- }
- size += curr.vv.Size()
- views = append(views, curr.vv.Views()...)
- }
- return buffer.NewVectorisedView(size, views), nil
-}
diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go
deleted file mode 100644
index 9ececcb9f..000000000
--- a/pkg/tcpip/network/fragmentation/frag_heap_test.go
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "container/heap"
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-var reassambleTestCases = []struct {
- comment string
- in []fragment
- want buffer.VectorisedView
-}{
- {
- comment: "Non-overlapping in-order",
- in: []fragment{
- {offset: 0, vv: vv(1, "0")},
- {offset: 1, vv: vv(1, "1")},
- },
- want: vv(2, "0", "1"),
- },
- {
- comment: "Non-overlapping out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(1, "1")},
- {offset: 0, vv: vv(1, "0")},
- },
- want: vv(2, "0", "1"),
- },
- {
- comment: "Duplicated packets",
- in: []fragment{
- {offset: 0, vv: vv(1, "0")},
- {offset: 0, vv: vv(1, "0")},
- },
- want: vv(1, "0"),
- },
- {
- comment: "Overlapping in-order",
- in: []fragment{
- {offset: 0, vv: vv(2, "01")},
- {offset: 1, vv: vv(2, "12")},
- },
- want: vv(3, "01", "2"),
- },
- {
- comment: "Overlapping out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(2, "12")},
- {offset: 0, vv: vv(2, "01")},
- },
- want: vv(3, "01", "2"),
- },
- {
- comment: "Overlapping subset in-order",
- in: []fragment{
- {offset: 0, vv: vv(3, "012")},
- {offset: 1, vv: vv(1, "1")},
- },
- want: vv(3, "012"),
- },
- {
- comment: "Overlapping subset out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(1, "1")},
- {offset: 0, vv: vv(3, "012")},
- },
- want: vv(3, "012"),
- },
-}
-
-func TestReassamble(t *testing.T) {
- for _, c := range reassambleTestCases {
- t.Run(c.comment, func(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- for _, f := range c.in {
- heap.Push(&h, f)
- }
- got, err := h.reassemble()
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(got, c.want) {
- t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want)
- }
- })
- }
-}
-
-func TestReassambleFailsForNonZeroOffset(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")})
- _, err := h.reassemble()
- if err == nil {
- t.Errorf("reassemble() did not fail when the first packet had offset != 0")
- }
-}
-
-func TestReassambleFailsForHoles(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")})
- heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")})
- _, err := h.reassemble()
- if err == nil {
- t.Errorf("reassemble() did not fail when there was a hole in the packet")
- }
-}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index d31296a41..1af87d713 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -53,6 +53,10 @@ var (
// ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps
// with another one.
ErrFragmentOverlap = errors.New("overlapping fragments")
+
+ // ErrFragmentConflict indicates that, during reassembly, some fragments are
+ // in conflict with one another.
+ ErrFragmentConflict = errors.New("conflicting fragments")
)
// FragmentID is the identifier for a fragment.
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 04072d966..9b20bb1d8 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -15,9 +15,8 @@
package fragmentation
import (
- "container/heap"
- "fmt"
"math"
+ "sort"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -29,6 +28,8 @@ type hole struct {
first uint16
last uint16
filled bool
+ final bool
+ data buffer.View
}
type reassembler struct {
@@ -39,7 +40,6 @@ type reassembler struct {
mu sync.Mutex
holes []hole
filled int
- heap fragHeap
done bool
creationTime int64
pkt *stack.PacketBuffer
@@ -48,51 +48,71 @@ type reassembler struct {
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
r := &reassembler{
id: id,
- holes: make([]hole, 0, 16),
- heap: make(fragHeap, 0, 8),
creationTime: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
first: 0,
last: math.MaxUint16,
filled: false,
+ final: true,
})
return r
}
-// updateHoles updates the list of holes for an incoming fragment. It returns
-// true if the fragment fits, it is not a duplicate and it does not overlap with
-// another fragment.
-//
-// For IPv6, overlaps with an existing fragment are explicitly forbidden by
-// RFC 8200 section 4.5:
-// If any of the fragments being reassembled overlap with any other fragments
-// being reassembled for the same packet, reassembly of that packet must be
-// abandoned and all the fragments that have been received for that packet
-// must be discarded, and no ICMP error messages should be sent.
-//
-// It is not explicitly forbidden for IPv4, but to keep parity with Linux we
-// disallow it as well:
-// https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349
-func (r *reassembler) updateHoles(first, last uint16, more bool) (bool, error) {
+func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.done {
+ // A concurrent goroutine might have already reassembled
+ // the packet and emptied the heap while this goroutine
+ // was waiting on the mutex. We don't have to do anything in this case.
+ return buffer.VectorisedView{}, 0, false, 0, nil
+ }
+
+ var holeFound bool
+ var consumed int
for i := range r.holes {
currentHole := &r.holes[i]
- if currentHole.filled || last < currentHole.first || currentHole.last < first {
+ if last < currentHole.first || currentHole.last < first {
continue
}
-
+ // For IPv6, overlaps with an existing fragment are explicitly forbidden by
+ // RFC 8200 section 4.5:
+ // If any of the fragments being reassembled overlap with any other
+ // fragments being reassembled for the same packet, reassembly of that
+ // packet must be abandoned and all the fragments that have been received
+ // for that packet must be discarded, and no ICMP error messages should be
+ // sent.
+ //
+ // It is not explicitly forbidden for IPv4, but to keep parity with Linux we
+ // disallow it as well:
+ // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349
if first < currentHole.first || currentHole.last < last {
// Incoming fragment only partially fits in the free hole.
- return false, ErrFragmentOverlap
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap
+ }
+ if !more {
+ if !currentHole.final || currentHole.filled && currentHole.last != last {
+ // We have another final fragment, which does not perfectly overlap.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
+ }
}
- r.filled++
+ holeFound = true
+ if currentHole.filled {
+ // Incoming fragment is a duplicate.
+ continue
+ }
+
+ // We are populating the current hole with the payload and creating a new
+ // hole for any unfilled ranges on either end.
if first > currentHole.first {
r.holes = append(r.holes, hole{
first: currentHole.first,
last: first - 1,
filled: false,
+ final: false,
})
}
if last < currentHole.last && more {
@@ -100,39 +120,22 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) (bool, error) {
first: last + 1,
last: currentHole.last,
filled: false,
+ final: currentHole.final,
})
+ currentHole.final = false
}
+ v := pkt.Data.ToOwnedView()
+ consumed = v.Size()
+ r.size += consumed
// Update the current hole to precisely match the incoming fragment.
r.holes[i] = hole{
first: first,
last: last,
filled: true,
+ final: currentHole.final,
+ data: v,
}
- return true, nil
- }
-
- // Incoming fragment is a duplicate/subset, or its offset comes after the end
- // of the reassembled payload.
- return false, nil
-}
-
-func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) {
- r.mu.Lock()
- defer r.mu.Unlock()
- if r.done {
- // A concurrent goroutine might have already reassembled
- // the packet and emptied the heap while this goroutine
- // was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, 0, false, 0, nil
- }
-
- used, err := r.updateHoles(first, last, more)
- if err != nil {
- return buffer.VectorisedView{}, 0, false, 0, fmt.Errorf("fragment reassembly failed: %w", err)
- }
-
- var consumed int
- if used {
+ r.filled++
// For IPv6, it is possible to have different Protocol values between
// fragments of a packet (because, unlike IPv4, the Protocol is not used to
// identify a fragment). In this case, only the Protocol of the first
@@ -145,22 +148,30 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
r.pkt = pkt
r.proto = proto
}
- vv := pkt.Data
- // We store the incoming packet only if it filled some holes.
- heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
- consumed = vv.Size()
- r.size += consumed
+
+ break
+ }
+ if !holeFound {
+ // Incoming fragment is beyond end.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
}
// Check if all the holes have been filled and we are ready to reassemble.
if r.filled < len(r.holes) {
return buffer.VectorisedView{}, 0, false, consumed, nil
}
- res, err := r.heap.reassemble()
- if err != nil {
- return buffer.VectorisedView{}, 0, false, 0, fmt.Errorf("fragment reassembly failed: %w", err)
+
+ sort.Slice(r.holes, func(i, j int) bool {
+ return r.holes[i].first < r.holes[j].first
+ })
+
+ var size int
+ views := make([]buffer.View, 0, len(r.holes))
+ for _, hole := range r.holes {
+ views = append(views, hole.data)
+ size += hole.data.Size()
}
- return res, r.proto, true, consumed, nil
+ return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil
}
func (r *reassembler) checkDoneOrMark() bool {
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index cee3063b1..2ff03eeeb 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -19,105 +19,156 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
-type updateHolesParams struct {
+type processParams struct {
first uint16
last uint16
more bool
- wantUsed bool
+ pkt *stack.PacketBuffer
+ wantDone bool
wantError error
}
-func TestUpdateHoles(t *testing.T) {
+func TestReassemblerProcess(t *testing.T) {
+ const proto = 99
+
+ v := func(size int) buffer.View {
+ payload := buffer.NewView(size)
+ for i := 1; i < size; i++ {
+ payload[i] = uint8(i) * 3
+ }
+ return payload
+ }
+
+ pkt := func(size int) *stack.PacketBuffer {
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: v(size).ToVectorisedView(),
+ })
+ }
+
var tests = []struct {
name string
- params []updateHolesParams
+ params []processParams
want []hole
}{
{
name: "No fragments",
params: nil,
- want: []hole{{first: 0, last: math.MaxUint16, filled: false}},
+ want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}},
},
{
name: "One fragment at beginning",
- params: []updateHolesParams{{first: 0, last: 1, more: true, wantUsed: true, wantError: nil}},
+ params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
want: []hole{
- {first: 0, last: 1, filled: true},
- {first: 2, last: math.MaxUint16, filled: false},
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: math.MaxUint16, filled: false, final: true},
},
},
{
name: "One fragment in the middle",
- params: []updateHolesParams{{first: 1, last: 2, more: true, wantUsed: true, wantError: nil}},
+ params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
want: []hole{
- {first: 1, last: 2, filled: true},
- {first: 0, last: 0, filled: false},
- {first: 3, last: math.MaxUint16, filled: false},
+ {first: 1, last: 2, filled: true, final: false, data: v(2)},
+ {first: 0, last: 0, filled: false, final: false},
+ {first: 3, last: math.MaxUint16, filled: false, final: true},
},
},
{
name: "One fragment at the end",
- params: []updateHolesParams{{first: 1, last: 2, more: false, wantUsed: true, wantError: nil}},
+ params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}},
want: []hole{
- {first: 1, last: 2, filled: true},
+ {first: 1, last: 2, filled: true, final: true, data: v(2)},
{first: 0, last: 0, filled: false},
},
},
{
name: "One fragment completing a packet",
- params: []updateHolesParams{{first: 0, last: 1, more: false, wantUsed: true, wantError: nil}},
+ params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}},
want: []hole{
- {first: 0, last: 1, filled: true},
+ {first: 0, last: 1, filled: true, final: true, data: v(2)},
},
},
{
name: "Two fragments completing a packet",
- params: []updateHolesParams{
- {first: 0, last: 1, more: true, wantUsed: true, wantError: nil},
- {first: 2, last: 3, more: false, wantUsed: true, wantError: nil},
+ params: []processParams{
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
},
want: []hole{
- {first: 0, last: 1, filled: true},
- {first: 2, last: 3, filled: true},
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: 3, filled: true, final: true, data: v(2)},
},
},
{
name: "Two fragments completing a packet with a duplicate",
- params: []updateHolesParams{
- {first: 0, last: 1, more: true, wantUsed: true, wantError: nil},
- {first: 0, last: 1, more: true, wantUsed: false, wantError: nil},
- {first: 2, last: 3, more: false, wantUsed: true, wantError: nil},
+ params: []processParams{
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: 3, filled: true, final: true, data: v(2)},
+ },
+ },
+ {
+ name: "Two fragments completing a packet with a partial duplicate",
+ params: []processParams{
+ {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil},
+ {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
},
want: []hole{
- {first: 0, last: 1, filled: true},
- {first: 2, last: 3, filled: true},
+ {first: 0, last: 3, filled: true, final: false, data: v(4)},
+ {first: 4, last: 5, filled: true, final: true, data: v(2)},
},
},
{
name: "Two overlapping fragments",
- params: []updateHolesParams{
- {first: 0, last: 10, more: true, wantUsed: true, wantError: nil},
- {first: 5, last: 15, more: false, wantUsed: false, wantError: ErrFragmentOverlap},
- {first: 11, last: 15, more: false, wantUsed: true, wantError: nil},
+ params: []processParams{
+ {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil},
+ {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap},
+ },
+ want: []hole{
+ {first: 0, last: 10, filled: true, final: false, data: v(11)},
+ {first: 11, last: math.MaxUint16, filled: false, final: true},
+ },
+ },
+ {
+ name: "Two final fragments with different ends",
+ params: []processParams{
+ {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
+ {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict},
+ },
+ want: []hole{
+ {first: 10, last: 14, filled: true, final: true, data: v(5)},
+ {first: 0, last: 9, filled: false, final: false},
+ },
+ },
+ {
+ name: "Two final fragments - duplicate",
+ params: []processParams{
+ {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
+ {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
},
want: []hole{
- {first: 0, last: 10, filled: true},
- {first: 11, last: 15, filled: true},
+ {first: 5, last: 14, filled: true, final: true, data: v(10)},
+ {first: 0, last: 4, filled: false, final: false},
},
},
{
- name: "Out of bounds fragment",
- params: []updateHolesParams{
- {first: 0, last: 10, more: true, wantUsed: true, wantError: nil},
- {first: 11, last: 15, more: false, wantUsed: true, wantError: nil},
- {first: 16, last: 20, more: false, wantUsed: false, wantError: nil},
+ name: "Two final fragments - duplicate, with different ends",
+ params: []processParams{
+ {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
+ {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict},
},
want: []hole{
- {first: 0, last: 10, filled: true},
- {first: 11, last: 15, filled: true},
+ {first: 5, last: 14, filled: true, final: true, data: v(10)},
+ {first: 0, last: 4, filled: false, final: false},
},
},
}
@@ -126,9 +177,9 @@ func TestUpdateHoles(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
r := newReassembler(FragmentID{}, &faketime.NullClock{})
for _, param := range test.params {
- used, err := r.updateHoles(param.first, param.last, param.more)
- if used != param.wantUsed || err != param.wantError {
- t.Errorf("got r.updateHoles(%d, %d, %t) = (%t, %v), want = (%t, %v)", param.first, param.last, param.more, used, err, param.wantUsed, param.wantError)
+ _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
+ if done != param.wantDone || err != param.wantError {
+ t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
}
}
if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" {
diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD
index 6ca200b48..ca1247c1e 100644
--- a/pkg/tcpip/network/ip/BUILD
+++ b/pkg/tcpip/network/ip/BUILD
@@ -18,6 +18,7 @@ go_test(
srcs = ["generic_multicast_protocol_test.go"],
deps = [
":ip",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/faketime",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
index e308550c4..f2f0e069c 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -30,6 +30,23 @@ type hostState int
// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1
// (RFC 2710 section 5). Even though the states are generic across both IGMPv2
// and MLDv1, IGMPv2 terminology will be used.
+//
+// ______________receive query______________
+// | |
+// | _____send or receive report_____ |
+// | | | |
+// V | V |
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+ |
+// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | -
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+
+// | ^ | ^ | ^ | ^
+// | | | | | | | |
+// ---------- ------- ---------- -------------
+// initialize new send inital fail to send send or receive
+// group membership report delayed report report
+//
+// Not shown in the diagram above, but any state may transition into the non
+// member state when a group is left.
const (
// nonMember is the "'Non-Member' state, when the host does not belong to the
// group on the interface. This is the initial state for all memberships on
@@ -41,6 +58,15 @@ const (
// but without advertising the membership to the network.
nonMember hostState = iota
+ // pendingMember is a newly joined member that is waiting to successfully send
+ // the initial set of reports.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the initial report needs to be sent.
+ //
+ // MAY NOT transition to the idle member state from this state.
+ pendingMember
+
// delayingMember is the "'Delaying Member' state, when the host belongs to
// the group on the interface and has a report delay timer running for that
// membership."
@@ -48,6 +74,16 @@ const (
// 'Delaying Listener' is the MLDv1 term used to describe this state.
delayingMember
+ // queuedDelayingMember is a delayingMember that failed to send a report after
+ // its delayed report timer fired. Hosts in this state are waiting to attempt
+ // retransmission of the delayed report.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the delayed report needs to be sent.
+ //
+ // May transition to idle member if a report is received for a group.
+ queuedDelayingMember
+
// idleMember is the "Idle Member" state, when the host belongs to the group
// on the interface and does not have a report delay timer running for that
// membership.
@@ -56,6 +92,17 @@ const (
idleMember
)
+func (s hostState) isDelayingMember() bool {
+ switch s {
+ case nonMember, pendingMember, idleMember:
+ return false
+ case delayingMember, queuedDelayingMember:
+ return true
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", s))
+ }
+}
+
// multicastGroupState holds the Generic Multicast Protocol state for a
// multicast group.
type multicastGroupState struct {
@@ -84,17 +131,6 @@ type multicastGroupState struct {
// GenericMulticastProtocolOptions holds options for the generic multicast
// protocol.
type GenericMulticastProtocolOptions struct {
- // Enabled indicates whether the generic multicast protocol will be
- // performed.
- //
- // When enabled, the protocol may transmit report and leave messages when
- // joining and leaving multicast groups respectively, and handle incoming
- // packets.
- //
- // When disabled, the protocol will still keep track of locally joined groups,
- // it just won't transmit and handle packets, or update groups' state.
- Enabled bool
-
// Rand is the source of random numbers.
Rand *rand.Rand
@@ -123,8 +159,22 @@ type GenericMulticastProtocolOptions struct {
// MulticastGroupProtocol is a multicast group protocol whose core state machine
// can be represented by GenericMulticastProtocolState.
type MulticastGroupProtocol interface {
+ // Enabled indicates whether the generic multicast protocol will be
+ // performed.
+ //
+ // When enabled, the protocol may transmit report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // packets.
+ //
+ // When disabled, the protocol will still keep track of locally joined groups,
+ // it just won't transmit and handle packets, or update groups' state.
+ Enabled() bool
+
// SendReport sends a multicast report for the specified group address.
- SendReport(groupAddress tcpip.Address) *tcpip.Error
+ //
+ // Returns false if the caller should queue the report to be sent later. Note,
+ // returning false does not mean that the receiver hit an error.
+ SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error)
// SendLeave sends a multicast leave for the specified group address.
SendLeave(groupAddress tcpip.Address) *tcpip.Error
@@ -138,76 +188,119 @@ type MulticastGroupProtocol interface {
// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state
// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710.
//
+// Callers must synchronize accesses to the generic multicast protocol state;
+// GenericMulticastProtocolState obtains no locks in any of its methods. The
+// only exception to this is GenericMulticastProtocolState's timer/job callbacks
+// which will obtain the lock provided to the GenericMulticastProtocolState when
+// it is initialized.
+//
// GenericMulticastProtocolState.Init MUST be called before calling any of
// the methods on GenericMulticastProtocolState.
+//
+// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the
+// multicast group protocol is disabled so that leave messages may be sent.
type GenericMulticastProtocolState struct {
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
+
opts GenericMulticastProtocolOptions
- mu struct {
- sync.RWMutex
+ // memberships holds group addresses and their associated state.
+ memberships map[tcpip.Address]multicastGroupState
- // memberships holds group addresses and their associated state.
- memberships map[tcpip.Address]multicastGroupState
- }
+ // protocolMU is the mutex used to protect the protocol.
+ protocolMU *sync.RWMutex
}
// Init initializes the Generic Multicast Protocol state.
-func (g *GenericMulticastProtocolState) Init(opts GenericMulticastProtocolOptions) {
- g.mu.Lock()
- defer g.mu.Unlock()
- g.opts = opts
- g.mu.memberships = make(map[tcpip.Address]multicastGroupState)
+//
+// Must only be called once for the lifetime of g; Init will panic if it is
+// called twice.
+//
+// The GenericMulticastProtocolState will only grab the lock when timers/jobs
+// fire.
+//
+// Note: the methods on opts.Protocol will always be called while protocolMU is
+// held.
+func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) {
+ if g.memberships != nil {
+ panic("attempted to initialize generic membership protocol state twice")
+ }
+
+ *g = GenericMulticastProtocolState{
+ opts: opts,
+ memberships: make(map[tcpip.Address]multicastGroupState),
+ protocolMU: protocolMU,
+ }
}
-// MakeAllNonMember transitions all groups to the non-member state.
+// MakeAllNonMemberLocked transitions all groups to the non-member state.
//
// The groups will still be considered joined locally.
-func (g *GenericMulticastProtocolState) MakeAllNonMember() {
- if !g.opts.Enabled {
+//
+// MUST be called when the multicast group protocol is disabled.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
+ if !g.opts.Protocol.Enabled() {
return
}
- g.mu.Lock()
- defer g.mu.Unlock()
-
- for groupAddress, info := range g.mu.memberships {
+ for groupAddress, info := range g.memberships {
g.transitionToNonMemberLocked(groupAddress, &info)
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
}
-// InitializeGroups initializes each group, as if they were newly joined but
-// without affecting the groups' join count.
+// InitializeGroupsLocked initializes each group, as if they were newly joined
+// but without affecting the groups' join count.
//
// Must only be called after calling MakeAllNonMember as a group should not be
// initialized while it is not in the non-member state.
-func (g *GenericMulticastProtocolState) InitializeGroups() {
- if !g.opts.Enabled {
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) InitializeGroupsLocked() {
+ if !g.opts.Protocol.Enabled() {
return
}
- g.mu.Lock()
- defer g.mu.Unlock()
-
- for groupAddress, info := range g.mu.memberships {
+ for groupAddress, info := range g.memberships {
g.initializeNewMemberLocked(groupAddress, &info)
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
}
-// JoinGroup handles joining a new group.
+// SendQueuedReportsLocked attempts to send reports for groups that failed to
+// send reports during their last attempt.
//
-// If dontInitialize is true, the group will be not be initialized and will be
-// left in the non-member state - no packets will be sent for it until it is
-// initialized via InitializeGroups.
-func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, dontInitialize bool) {
- g.mu.Lock()
- defer g.mu.Unlock()
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() {
+ for groupAddress, info := range g.memberships {
+ switch info.state {
+ case nonMember, delayingMember, idleMember:
+ case pendingMember:
+ // pendingMembers failed to send their initial unsolicited report so try
+ // to send the report and queue the extra unsolicited reports.
+ g.maybeSendInitialReportLocked(groupAddress, &info)
+ case queuedDelayingMember:
+ // queuedDelayingMembers failed to send their delayed reports so try to
+ // send the report and transition them to the idle state.
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", info.state))
+ }
+ g.memberships[groupAddress] = info
+ }
+}
- if info, ok := g.mu.memberships[groupAddress]; ok {
+// JoinGroupLocked handles joining a new group.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) {
+ if info, ok := g.memberships[groupAddress]; ok {
// The group has already been joined.
info.joins++
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
return
}
@@ -217,41 +310,43 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do
// The state will be updated below, if required.
state: nonMember,
lastToSendReport: false,
- delayedReportJob: tcpip.NewJob(g.opts.Clock, &g.mu, func() {
- info, ok := g.mu.memberships[groupAddress]
+ delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() {
+ if !g.opts.Protocol.Enabled() {
+ panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress))
+ }
+
+ info, ok := g.memberships[groupAddress]
if !ok {
panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress))
}
- info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil
- info.state = idleMember
- g.mu.memberships[groupAddress] = info
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
}),
}
- if !dontInitialize && g.opts.Enabled {
+ if g.opts.Protocol.Enabled() {
g.initializeNewMemberLocked(groupAddress, &info)
}
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
-// IsLocallyJoined returns true if the group is locally joined.
-func (g *GenericMulticastProtocolState) IsLocallyJoined(groupAddress tcpip.Address) bool {
- g.mu.RLock()
- defer g.mu.RUnlock()
- _, ok := g.mu.memberships[groupAddress]
+// IsLocallyJoinedRLocked returns true if the group is locally joined.
+//
+// Precondition: g.protocolMU must be read locked.
+func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool {
+ _, ok := g.memberships[groupAddress]
return ok
}
-// LeaveGroup handles leaving the group.
+// LeaveGroupLocked handles leaving the group.
//
// Returns false if the group is not currently joined.
-func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) bool {
- g.mu.Lock()
- defer g.mu.Unlock()
-
- info, ok := g.mu.memberships[groupAddress]
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool {
+ info, ok := g.memberships[groupAddress]
if !ok {
return false
}
@@ -262,30 +357,30 @@ func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) b
info.joins--
if info.joins != 0 {
// If we still have outstanding joins, then do nothing further.
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
return true
}
g.transitionToNonMemberLocked(groupAddress, &info)
- delete(g.mu.memberships, groupAddress)
+ delete(g.memberships, groupAddress)
return true
}
-// HandleQuery handles a query message with the specified maximum response time.
+// HandleQueryLocked handles a query message with the specified maximum response
+// time.
//
// If the group address is unspecified, then reports will be scheduled for all
// joined groups.
//
// Report(s) will be scheduled to be sent after a random duration between 0 and
// the maximum response time.
-func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) {
- if !g.opts.Enabled {
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) {
+ if !g.opts.Protocol.Enabled() {
return
}
- g.mu.Lock()
- defer g.mu.Unlock()
-
// As per RFC 2236 section 2.4 (for IGMPv2),
//
// In a Membership Query message, the group address field is set to zero
@@ -299,28 +394,27 @@ func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address,
// when sending a Multicast-Address-Specific Query.
if groupAddress.Unspecified() {
// This is a general query as the group address is unspecified.
- for groupAddress, info := range g.mu.memberships {
+ for groupAddress, info := range g.memberships {
g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
- } else if info, ok := g.mu.memberships[groupAddress]; ok {
+ } else if info, ok := g.memberships[groupAddress]; ok {
g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
}
-// HandleReport handles a report message.
+// HandleReportLocked handles a report message.
//
// If the report is for a joined group, any active delayed report will be
// cancelled and the host state for the group transitions to idle.
-func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) {
- if !g.opts.Enabled {
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) {
+ if !g.opts.Protocol.Enabled() {
return
}
- g.mu.Lock()
- defer g.mu.Unlock()
-
// As per RFC 2236 section 3 pages 3-4 (for IGMPv2),
//
// If the host receives another host's Report (version 1 or 2) while it has
@@ -333,23 +427,23 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address)
// multicast address while it has a timer running for that same address
// on that interface, it stops its timer and does not send a Report for
// that address, thus suppressing duplicate reports on the link.
- if info, ok := g.mu.memberships[groupAddress]; ok && info.state == delayingMember {
+ if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() {
info.delayedReportJob.Cancel()
info.lastToSendReport = false
info.state = idleMember
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
}
// initializeNewMemberLocked initializes a new group membership.
//
-// Precondition: g.mu must be locked.
+// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
if info.state != nonMember {
- panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state))
+ panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state))
}
- info.state = idleMember
+ info.lastToSendReport = false
if groupAddress == g.opts.AllNodesAddress {
// As per RFC 2236 section 6 page 10 (for IGMPv2),
@@ -365,9 +459,25 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t
// case. The node starts in Idle Listener state for that address on
// every interface, never transitions to another state, and never sends
// a Report or Done for that address.
+ info.state = idleMember
return
}
+ info.state = pendingMember
+ g.maybeSendInitialReportLocked(groupAddress, info)
+}
+
+// maybeSendInitialReportLocked attempts to start transmission of the initial
+// set of reports after newly joining a group.
+//
+// Host must be in pending member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != pendingMember {
+ panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
// As per RFC 2236 section 3 page 5 (for IGMPv2),
//
// When a host joins a multicast group, it should immediately transmit an
@@ -385,13 +495,35 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t
//
// TODO(gvisor.dev/issue/4901): Support a configurable number of initial
// unsolicited reports.
- info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil
- g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay)
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay)
+ }
+}
+
+// maybeSendDelayedReportLocked attempts to send the delayed report.
+//
+// Host must be in pending, delaying or queued delaying member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if !info.state.isDelayingMember() {
+ panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ info.state = idleMember
+ } else {
+ info.state = queuedDelayingMember
+ }
}
// maybeSendLeave attempts to send a leave message.
func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) {
- if !g.opts.Enabled || !lastToSendReport {
+ if !g.opts.Protocol.Enabled() || !lastToSendReport {
return
}
@@ -465,7 +597,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres
// transitionToNonMemberLocked transitions the given multicast group the the
// non-member/listener state.
//
-// Precondition: e.mu must be locked.
+// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
if info.state == nonMember {
return
@@ -479,7 +611,7 @@ func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress
// setDelayTimerForAddressRLocked sets timer to send a delay report.
//
-// Precondition: g.mu MUST be read locked.
+// Precondition: g.protocolMU MUST be read locked.
func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) {
if info.state == nonMember {
return
@@ -517,6 +649,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr
// TODO: Reset the timer if time remaining is greater than maxResponseTime.
return
}
+
info.state = delayingMember
info.delayedReportJob.Cancel()
info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime))
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
index 670be30d4..85593f211 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -20,6 +20,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/network/ip"
@@ -36,42 +37,178 @@ const (
var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
-type mockMulticastGroupProtocol struct {
+type mockMulticastGroupProtocolProtectedFields struct {
+ sync.RWMutex
+
+ genericMulticastGroup ip.GenericMulticastProtocolState
sendReportGroupAddrCount map[tcpip.Address]int
sendLeaveGroupAddrCount map[tcpip.Address]int
+ makeQueuePackets bool
+ disabled bool
}
-func (m *mockMulticastGroupProtocol) init() {
- m.sendReportGroupAddrCount = make(map[tcpip.Address]int)
- m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
+type mockMulticastGroupProtocol struct {
+ t *testing.T
+
+ mu mockMulticastGroupProtocolProtectedFields
}
-func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error {
- m.sendReportGroupAddrCount[groupAddress]++
- return nil
+func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.initLocked()
+ opts.Protocol = m
+ m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts)
+}
+
+func (m *mockMulticastGroupProtocol) initLocked() {
+ m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int)
+ m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
+}
+
+func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.disabled = !v
+}
+
+func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.makeQueuePackets = v
}
+func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.JoinGroupLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.mu.genericMulticastGroup.LeaveGroupLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.HandleReportLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime)
+}
+
+func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) makeAllNonMember() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.MakeAllNonMemberLocked()
+}
+
+func (m *mockMulticastGroupProtocol) initializeGroups() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.InitializeGroupsLocked()
+}
+
+func (m *mockMulticastGroupProtocol) sendQueuedReports() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.SendQueuedReportsLocked()
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be read locked.
+func (m *mockMulticastGroupProtocol) Enabled() bool {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
+ }
+
+ return !m.mu.disabled
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
+func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+
+ m.mu.sendReportGroupAddrCount[groupAddress]++
+ return !m.mu.makeQueuePackets, nil
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
- m.sendLeaveGroupAddrCount[groupAddress]++
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+
+ m.mu.sendLeaveGroupAddrCount[groupAddress]++
return nil
}
-func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
- sendReportGroupAddressesMap := make(map[tcpip.Address]int)
+func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ sendReportGroupAddrCount := make(map[tcpip.Address]int)
for _, a := range sendReportGroupAddresses {
- sendReportGroupAddressesMap[a] = 1
+ sendReportGroupAddrCount[a] = 1
}
- sendLeaveGroupAddressesMap := make(map[tcpip.Address]int)
+ sendLeaveGroupAddrCount := make(map[tcpip.Address]int)
for _, a := range sendLeaveGroupAddresses {
- sendLeaveGroupAddressesMap[a] = 1
+ sendLeaveGroupAddrCount[a] = 1
}
- diff := cmp.Diff(mockMulticastGroupProtocol{
- sendReportGroupAddrCount: sendReportGroupAddressesMap,
- sendLeaveGroupAddrCount: sendLeaveGroupAddressesMap,
- }, *mgp, cmp.AllowUnexported(mockMulticastGroupProtocol{}))
- mgp.init()
+ diff := cmp.Diff(
+ &mockMulticastGroupProtocol{
+ mu: mockMulticastGroupProtocolProtectedFields{
+ sendReportGroupAddrCount: sendReportGroupAddrCount,
+ sendLeaveGroupAddrCount: sendLeaveGroupAddrCount,
+ },
+ },
+ m,
+ cmp.AllowUnexported(mockMulticastGroupProtocol{}),
+ cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}),
+ // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
+ cmp.FilterPath(
+ func(p cmp.Path) bool {
+ switch p.Last().String() {
+ case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup":
+ return true
+ }
+ return false
+ },
+ cmp.Ignore(),
+ ),
+ )
+ m.initLocked()
return diff
}
@@ -95,36 +232,34 @@ func TestJoinGroup(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
- Enabled: true,
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(0)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
AllNodesAddress: addr2,
})
// Joining a group should send a report immediately and another after
// a random interval between 0 and the maximum unsolicited report delay.
- g.JoinGroup(test.addr, false /* dontInitialize */)
+ mgp.joinGroup(test.addr)
if test.shouldSendReports {
- if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Should have no more messages to send.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
@@ -151,40 +286,42 @@ func TestLeaveGroup(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
- Enabled: true,
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(1)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
AllNodesAddress: addr2,
})
- g.JoinGroup(test.addr, false /* dontInitialize */)
+ mgp.joinGroup(test.addr)
if test.shouldSendMessages {
- if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Leaving a group should send a leave report immediately and cancel any
// delayed reports.
- if !g.LeaveGroup(test.addr) {
- t.Fatalf("got g.LeaveGroup(%s) = false, want = true", test.addr)
+ {
+
+ if !mgp.leaveGroup(test.addr) {
+ t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr)
+ }
}
if test.shouldSendMessages {
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
@@ -226,45 +363,43 @@ func TestHandleReport(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
- Enabled: true,
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(2)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
AllNodesAddress: addr3,
})
- g.JoinGroup(addr1, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr2, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr3, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.joinGroup(addr3)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Receiving a report for a group we have a timer scheduled for should
// cancel our delayed report timer for the group.
- g.HandleReport(test.reportAddr)
+ mgp.handleReport(test.reportAddr)
if len(test.expectReportsFor) != 0 {
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Should have no more messages to send.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
@@ -312,49 +447,47 @@ func TestHandleQuery(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
- Enabled: true,
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
AllNodesAddress: addr3,
})
- g.JoinGroup(addr1, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr2, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr3, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.joinGroup(addr3)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Receiving a query should make us schedule a new delayed report if it
// is a query directed at us or a general query.
- g.HandleQuery(test.queryAddr, test.maxDelay)
+ mgp.handleQuery(test.queryAddr, test.maxDelay)
if len(test.expectReportsFor) != 0 {
clock.Advance(test.maxDelay)
- if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Should have no more messages to send.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
@@ -362,133 +495,139 @@ func TestHandleQuery(t *testing.T) {
}
func TestJoinCount(t *testing.T) {
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
- Enabled: true,
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(4)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: time.Second,
})
// Set the join count to 2 for a group.
- g.JoinGroup(addr1, false /* dontInitialize */)
- if !g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
}
// Only the first join should trigger a report to be sent.
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr1, false /* dontInitialize */)
- if !g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
}
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
}
// Group should still be considered joined after leaving once.
- if !g.LeaveGroup(addr1) {
- t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1)
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
}
- if !g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
}
// A leave report should only be sent once the join count reaches 0.
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
}
// Leaving once more should actually remove us from the group.
- if !g.LeaveGroup(addr1) {
- t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1)
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
}
- if g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1)
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
}
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
}
// Group should no longer be joined so we should not have anything to
// leave.
- if g.LeaveGroup(addr1) {
- t.Fatalf("got g.LeaveGroup(%s) = true, want = false", addr1)
+ if mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1)
}
- if g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1)
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
}
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
func TestMakeAllNonMemberAndInitialize(t *testing.T) {
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
- Enabled: true,
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
- Protocol: &mgp,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
AllNodesAddress: addr3,
})
- g.JoinGroup(addr1, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr2, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr3, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.joinGroup(addr3)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Should send the leave reports for each but still consider them locally
// joined.
- g.MakeAllNonMember()
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.makeAllNonMember()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
for _, group := range []tcpip.Address{addr1, addr2, addr3} {
- if !g.IsLocallyJoined(group) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", group)
+ if !mgp.isLocallyJoined(group) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group)
}
}
// Should send the initial set of unsolcited reports.
- g.InitializeGroups()
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.initializeGroups()
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Should have no more messages to send.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
@@ -496,81 +635,172 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) {
// TestGroupStateNonMember tests that groups do not send packets when in the
// non-member state, but are still considered locally joined.
func TestGroupStateNonMember(t *testing.T) {
- tests := []struct {
- name string
- enabled bool
- dontInitialize bool
- }{
- {
- name: "Disabled",
- enabled: false,
- dontInitialize: false,
- },
- {
- name: "Keep non-member",
- enabled: true,
- dontInitialize: true,
- },
- {
- name: "disabled and Keep non-member",
- enabled: false,
- dontInitialize: true,
- },
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+ mgp.setEnabled(false)
+
+ // Joining groups should not send any reports.
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr2)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
- clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
- Enabled: test.enabled,
- Rand: rand.New(rand.NewSource(3)),
- Clock: clock,
- Protocol: &mgp,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
+ // Receiving a query should not send any reports.
+ mgp.handleQuery(addr1, time.Nanosecond)
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
- g.JoinGroup(addr1, test.dontInitialize)
- if !g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
- }
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ // Leaving groups should not send any leave messages.
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
- g.JoinGroup(addr2, test.dontInitialize)
- if !g.IsLocallyJoined(addr2) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr2)
- }
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
- g.HandleQuery(addr1, time.Nanosecond)
- clock.Advance(time.Nanosecond)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+func TestQueuedPackets(t *testing.T) {
+ clock := faketime.NewManualClock()
+ mgp := mockMulticastGroupProtocol{t: t}
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
- if !g.LeaveGroup(addr2) {
- t.Errorf("got g.LeaveGroup(%s) = false, want = true", addr2)
- }
- if !g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
- }
- if g.IsLocallyJoined(addr2) {
- t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr2)
- }
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
+ // Joining should trigger a SendReport, but mgp should report that we did not
+ // send the packet.
+ mgp.setQueuePackets(true)
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
- clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- })
+ // The delayed report timer should have been cancelled since we did not send
+ // the initial report earlier.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to successfully send the report.
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send (we should be idle).
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query but mock being unable to send reports again.
+ mgp.setQueuePackets(true)
+ mgp.handleQuery(addr1, time.Nanosecond)
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to send reports again - we should have a packet queued to
+ // send.
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query again, but mock being unable to send reports.
+ mgp.setQueuePackets(true)
+ mgp.handleQuery(addr1, time.Nanosecond)
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a report should should transition us into the idle member state,
+ // even if we had a packet queued. We should no longer have any packets to
+ // send.
+ mgp.handleReport(addr1)
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // When we fail to send the initial set of reports, incoming reports should
+ // not affect a newly joined group's reports from being sent.
+ mgp.setQueuePackets(true)
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.handleReport(addr2)
+ // Attempting to send queued reports while still unable to send reports should
+ // not change the host state.
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Mock being able to successfully send the report.
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index a314dd386..3005973d7 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -344,11 +344,11 @@ func TestSourceAddressValidation(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv6Addr,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv6Addr,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -619,11 +619,11 @@ func TestReceive(t *testing.T) {
view := buffer.NewView(header.IPv6MinimumSize + payloadLen)
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: payloadLen,
- NextHeader: 10,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: remoteIPv6Addr,
- DstAddr: localIPv6Addr,
+ PayloadLength: payloadLen,
+ TransportProtocol: 10,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
})
// Make payload be non-zero.
@@ -993,11 +993,11 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Create the outer IPv6 header.
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 20,
- SrcAddr: outerSrcAddr,
- DstAddr: localIPv6Addr,
+ PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 20,
+ SrcAddr: outerSrcAddr,
+ DstAddr: localIPv6Addr,
})
// Create the ICMP header.
@@ -1007,28 +1007,27 @@ func TestIPv6ReceiveControl(t *testing.T) {
icmp.SetIdent(0xdead)
icmp.SetSequence(0xbeef)
- // Create the inner IPv6 header.
- ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
- ip.Encode(&header.IPv6Fields{
- PayloadLength: 100,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: localIPv6Addr,
- DstAddr: remoteIPv6Addr,
- })
-
+ var extHdrs header.IPv6ExtHdrSerializer
// Build the fragmentation header if needed.
if c.fragmentOffset != nil {
- ip.SetNextHeader(header.IPv6FragmentHeader)
- frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:])
- frag.Encode(&header.IPv6FragmentFields{
- NextHeader: 10,
+ extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{
FragmentOffset: *c.fragmentOffset,
M: true,
Identification: 0x12345678,
})
}
+ // Create the inner IPv6 header.
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: 100,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: localIPv6Addr,
+ DstAddr: remoteIPv6Addr,
+ ExtensionHeaders: extHdrs,
+ })
+
// Make payload be non-zero.
for i := dataOffset; i < len(view); i++ {
view[i] = uint8(i)
@@ -1344,10 +1343,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return hdr.View().ToVectorisedView()
},
@@ -1387,10 +1386,12 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier),
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ // NB: we're lying about transport protocol here to verify the raw
+ // fragment header bytes.
+ TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return hdr.View().ToVectorisedView()
},
@@ -1422,10 +1423,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return buffer.View(ip).ToVectorisedView()
},
@@ -1457,10 +1458,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index 0134fadc0..da88d65d1 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -16,7 +16,6 @@ package ipv4
import (
"fmt"
- "sync"
"sync/atomic"
"time"
@@ -58,6 +57,9 @@ type IGMPOptions struct {
// When enabled, IGMP may transmit IGMP report and leave messages when
// joining and leaving multicast groups respectively, and handle incoming
// IGMP packets.
+ //
+ // This field is ignored and is always assumed to be false for interfaces
+ // without neighbouring nodes (e.g. loopback).
Enabled bool
}
@@ -68,8 +70,9 @@ var _ ip.MulticastGroupProtocol = (*igmpState)(nil)
// igmpState.init() MUST be called after creating an IGMP state.
type igmpState struct {
// The IPv4 endpoint this igmpState is for.
- ep *endpoint
- opts IGMPOptions
+ ep *endpoint
+
+ genericMulticastProtocol ip.GenericMulticastProtocolState
// igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from
// RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1
@@ -84,20 +87,23 @@ type igmpState struct {
// when false.
igmpV1Present uint32
- mu struct {
- sync.RWMutex
-
- genericMulticastProtocol ip.GenericMulticastProtocolState
+ // igmpV1Job is scheduled when this interface receives an IGMPv1 style
+ // message, upon expiration the igmpV1Present flag is cleared.
+ // igmpV1Job may not be nil once igmpState is initialized.
+ igmpV1Job *tcpip.Job
+}
- // igmpV1Job is scheduled when this interface receives an IGMPv1 style
- // message, upon expiration the igmpV1Present flag is cleared.
- // igmpV1Job may not be nil once igmpState is initialized.
- igmpV1Job *tcpip.Job
- }
+// Enabled implements ip.MulticastGroupProtocol.
+func (igmp *igmpState) Enabled() bool {
+ // No need to perform IGMP on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return igmp.ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() && igmp.ep.Enabled()
}
// SendReport implements ip.MulticastGroupProtocol.
-func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error {
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
igmpType := header.IGMPv2MembershipReport
if igmp.v1Present() {
igmpType = header.IGMPv1MembershipReport
@@ -106,6 +112,8 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error {
}
// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: igmp.ep.mu must be read locked.
func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
// As per RFC 2236 Section 6, Page 8: "If the interface state says the
// Querier is running IGMPv1, this action SHOULD be skipped. If the flag
@@ -114,18 +122,17 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
if igmp.v1Present() {
return nil
}
- return igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup)
+ _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup)
+ return err
}
// init sets up an igmpState struct, and is required to be called before using
// a new igmpState.
-func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
+//
+// Must only be called once for the lifetime of igmp.
+func (igmp *igmpState) init(ep *endpoint) {
igmp.ep = ep
- igmp.opts = opts
- igmp.mu.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{
- Enabled: opts.Enabled,
+ igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
Rand: ep.protocol.stack.Rand(),
Clock: ep.protocol.stack.Clock(),
Protocol: igmp,
@@ -133,11 +140,14 @@ func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) {
AllNodesAddress: header.IPv4AllSystems,
})
igmp.igmpV1Present = igmpV1PresentDefault
- igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() {
+ igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() {
igmp.setV1Present(false)
})
}
+// handleIGMP handles an IGMP packet.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
stats := igmp.ep.protocol.stack.Stats()
received := stats.IGMP.PacketsReceived
@@ -207,32 +217,34 @@ func (igmp *igmpState) setV1Present(v bool) {
}
}
+// handleMembershipQuery handles a membership query.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
-
// As per RFC 2236 Section 6, Page 10: If the maximum response time is zero
// then change the state to note that an IGMPv1 router is present and
// schedule the query received Job.
- if maxRespTime == 0 && igmp.opts.Enabled {
- igmp.mu.igmpV1Job.Cancel()
- igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout)
+ if maxRespTime == 0 && igmp.Enabled() {
+ igmp.igmpV1Job.Cancel()
+ igmp.igmpV1Job.Schedule(v1RouterPresentTimeout)
igmp.setV1Present(true)
maxRespTime = v1MaxRespTime
}
- igmp.mu.genericMulticastProtocol.HandleQuery(groupAddress, maxRespTime)
+ igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime)
}
+// handleMembershipReport handles a membership report.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
- igmp.mu.genericMulticastProtocol.HandleReport(groupAddress)
+ igmp.genericMulticastProtocol.HandleReportLocked(groupAddress)
}
-// writePacket assembles and sends an IGMP packet with the provided fields,
-// incrementing the provided stat counter on success.
-func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) *tcpip.Error {
+// writePacket assembles and sends an IGMP packet.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) {
igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize))
igmpData.SetType(igmpType)
igmpData.SetGroupAddress(groupAddress)
@@ -243,9 +255,13 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
Data: buffer.View(igmpData).ToVectorisedView(),
})
- // TODO(gvisor.dev/issue/4888): We should not use the unspecified address,
- // rather we should select an appropriate local address.
- localAddr := header.IPv4Any
+ addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */)
+ if addressEndpoint == nil {
+ return false, nil
+ }
+ localAddr := addressEndpoint.AddressWithPrefix().Address
+ addressEndpoint.DecRef()
+ addressEndpoint = nil
igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.IGMPProtocolNumber,
TTL: header.IGMPTTL,
@@ -254,22 +270,22 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
&header.IPv4SerializableRouterAlertOption{},
})
- sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
+ sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sent.Dropped.Increment()
- return err
+ sentStats.Dropped.Increment()
+ return false, err
}
switch igmpType {
case header.IGMPv1MembershipReport:
- sent.V1MembershipReport.Increment()
+ sentStats.V1MembershipReport.Increment()
case header.IGMPv2MembershipReport:
- sent.V2MembershipReport.Increment()
+ sentStats.V2MembershipReport.Increment()
case header.IGMPLeaveGroup:
- sent.LeaveGroup.Increment()
+ sentStats.LeaveGroup.Increment()
default:
panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType))
}
- return nil
+ return true, nil
}
// joinGroup handles adding a new group to the membership map, setting up the
@@ -278,28 +294,27 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
//
// If the group already exists in the membership map, returns
// tcpip.ErrDuplicateAddress.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
- igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress, !igmp.ep.Enabled() /* dontInitialize */)
+ igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress)
}
// isInGroup returns true if the specified group has been joined locally.
+//
+// Precondition: igmp.ep.mu must be read locked.
func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
- return igmp.mu.genericMulticastProtocol.IsLocallyJoined(groupAddress)
+ return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
}
// leaveGroup handles removing the group from the membership map, cancels any
// delay timers associated with that group, and sends the Leave Group message
// if required.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
-
// LeaveGroup returns false only if the group was not joined.
- if igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress) {
+ if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
return nil
}
@@ -308,16 +323,23 @@ func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
// softLeaveAll leaves all groups from the perspective of IGMP, but remains
// joined locally.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) softLeaveAll() {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
- igmp.mu.genericMulticastProtocol.MakeAllNonMember()
+ igmp.genericMulticastProtocol.MakeAllNonMemberLocked()
}
// initializeAll attemps to initialize the IGMP state for each group that has
// been joined locally.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) initializeAll() {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
- igmp.mu.genericMulticastProtocol.InitializeGroups()
+ igmp.genericMulticastProtocol.InitializeGroupsLocked()
+}
+
+// sendQueuedReports attempts to send any reports that are queued for sending.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) sendQueuedReports() {
+ igmp.genericMulticastProtocol.SendQueuedReportsLocked()
}
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index 5e139377b..1ee573ac8 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -16,6 +16,7 @@ package ipv4_test
import (
"testing"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -29,6 +30,7 @@ import (
const (
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ addr = tcpip.Address("\x0a\x00\x00\x01")
multicastAddr = tcpip.Address("\xe0\x00\x00\x03")
nicID = 1
)
@@ -41,6 +43,7 @@ func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.
payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
checker.IPv4(t, payload,
+ checker.SrcAddr(addr),
checker.DstAddr(remoteAddress),
// TTL for an IGMP message must be 1 as per RFC 2236 section 2.
checker.TTL(1),
@@ -71,7 +74,6 @@ func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stac
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
-
return e, s, clock
}
@@ -104,6 +106,9 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma
// reports for backwards compatibility.
func TestIgmpV1Present(t *testing.T) {
e, s, clock := createStack(t, true)
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ }
if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
@@ -154,3 +159,57 @@ func TestIgmpV1Present(t *testing.T) {
}
validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr)
}
+
+func TestSendQueuedIGMPReports(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ // Joining a group without an assigned address should queue IGMP packets; none
+ // should be sent without an assigned address.
+ if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err)
+ }
+ reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport
+ if got := reportStat.Value(); got != 0 {
+ t.Errorf("got reportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got unexpected packet = %#v", p)
+ }
+
+ // The initial set of IGMP reports that were queued should be sent once an
+ // address is assigned.
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ }
+ if got := reportStat.Value(); got != 1 {
+ t.Errorf("got reportStat.Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Error("expected to send an IGMP membership report")
+ } else {
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ clock.Advance(ipv4.UnsolicitedReportIntervalMax)
+ if got := reportStat.Value(); got != 2 {
+ t.Errorf("got reportStat.Value() = %d, want = 2", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Error("expected to send an IGMP membership report")
+ } else {
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Should have no more packets to send after the initial set of unsolicited
+ // reports.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got unexpected packet = %#v", p)
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 3076185cd..e9ff70d04 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -72,7 +72,6 @@ type endpoint struct {
nic stack.NetworkInterface
dispatcher stack.TransportDispatcher
protocol *protocol
- igmp igmpState
// enabled is set to 1 when the enpoint is enabled and 0 when it is
// disabled.
@@ -84,6 +83,7 @@ type endpoint struct {
sync.RWMutex
addressableEndpointState stack.AddressableEndpointState
+ igmp igmpState
}
}
@@ -94,8 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa
dispatcher: dispatcher,
protocol: p,
}
+ e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
- e.igmp.init(e, p.options.IGMP)
+ e.mu.igmp.init(e)
+ e.mu.Unlock()
return e
}
@@ -127,7 +129,7 @@ func (e *endpoint) Enable() *tcpip.Error {
// endpoint may have left groups from the perspective of IGMP when the
// endpoint was disabled. Either way, we need to let routers know to
// send us multicast traffic.
- e.igmp.initializeAll()
+ e.mu.igmp.initializeAll()
// As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
// multicast group. Note, the IANA calls the all-hosts multicast group the
@@ -170,7 +172,7 @@ func (e *endpoint) Disable() {
}
func (e *endpoint) disableLocked() {
- if !e.setEnabled(false) {
+ if !e.isEnabled() {
return
}
@@ -181,12 +183,16 @@ func (e *endpoint) disableLocked() {
// Leave groups from the perspective of IGMP so that routers know that
// we are no longer interested in the group.
- e.igmp.softLeaveAll()
+ e.mu.igmp.softLeaveAll()
// The address may have already been removed.
if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
}
+
+ if !e.setEnabled(false) {
+ panic("should have only done work to disable the endpoint if it was enabled")
+ }
}
// DefaultTTL is the default time-to-live value for this endpoint.
@@ -718,7 +724,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return
}
if p == header.IGMPProtocolNumber {
- e.igmp.handleIGMP(pkt)
+ e.mu.Lock()
+ e.mu.igmp.handleIGMP(pkt)
+ e.mu.Unlock()
return
}
if opts := h.Options(); len(opts) != 0 {
@@ -776,7 +784,12 @@ func (e *endpoint) Close() {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ if err == nil {
+ e.mu.igmp.sendQueuedReports()
+ }
+ return ep, err
}
// RemovePermanentAddress implements stack.AddressableEndpoint.
@@ -811,6 +824,14 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo
func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
e.mu.RLock()
defer e.mu.RUnlock()
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
+}
+
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
}
@@ -843,7 +864,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadAddress
}
- e.igmp.joinGroup(addr)
+ e.mu.igmp.joinGroup(addr)
return nil
}
@@ -858,14 +879,14 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
//
// Precondition: e.mu must be locked.
func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
- return e.igmp.leaveGroup(addr)
+ return e.mu.igmp.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.igmp.isInGroup(addr)
+ return e.mu.igmp.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 9e2d2cfd6..ef62fe6fc 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -2669,8 +2669,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != header.IPv4ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
}
- if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
}
checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
@@ -2712,8 +2712,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != header.IPv4ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
}
- if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
}
checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
@@ -2761,8 +2761,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != arp.ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber)
}
- if got := p.Route.RemoteLinkAddress(); got != header.EthernetBroadcastAddress {
- t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, header.EthernetBroadcastAddress)
+ if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
}
rep := header.ARP(p.Pkt.NetworkHeader().View())
if got := rep.Op(); got != header.ARPRequest {
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 5e75c8740..afa45aefe 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -58,7 +58,10 @@ go_test(
srcs = ["mld_test.go"],
deps = [
":ipv6",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 510276b8e..6ee162713 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -645,26 +645,34 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone:
- var handler func(header.MLD)
switch icmpType {
case header.ICMPv6MulticastListenerQuery:
received.MulticastListenerQuery.Increment()
- handler = e.mld.handleMulticastListenerQuery
case header.ICMPv6MulticastListenerReport:
received.MulticastListenerReport.Increment()
- handler = e.mld.handleMulticastListenerReport
case header.ICMPv6MulticastListenerDone:
received.MulticastListenerDone.Increment()
default:
panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
}
+
if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize {
received.Invalid.Increment()
return
}
- if handler != nil {
- handler(header.MLD(payload.ToView()))
+ switch icmpType {
+ case header.ICMPv6MulticastListenerQuery:
+ e.mu.Lock()
+ e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView()))
+ e.mu.Unlock()
+ case header.ICMPv6MulticastListenerReport:
+ e.mu.Lock()
+ e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView()))
+ e.mu.Unlock()
+ case header.ICMPv6MulticastListenerDone:
+ default:
+ panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
}
default:
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 32adb5c83..34a6a8446 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -149,9 +149,8 @@ func (*testInterface) Promiscuous() bool {
}
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- r := stack.Route{
- NetProto: protocol,
- }
+ var r stack.Route
+ r.NetProto = protocol
r.ResolveWith(remoteLinkAddr)
return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
}
@@ -296,11 +295,11 @@ func TestICMPCounts(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
ep.HandlePacket(pkt)
}
@@ -454,11 +453,11 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
ep.HandlePacket(pkt)
}
@@ -600,8 +599,8 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
return
}
- if got := pi.Route.RemoteLinkAddress(); len(args.remoteLinkAddr) != 0 && got != args.remoteLinkAddr {
- t.Errorf("got remote link address = %s, want = %s", got, args.remoteLinkAddr)
+ if len(args.remoteLinkAddr) != 0 && pi.Route.RemoteLinkAddress != args.remoteLinkAddr {
+ t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
}
// Pull the full payload since network header. Needed for header.IPv6 to
@@ -853,11 +852,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
}
ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
@@ -930,11 +929,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
errorICMPBody := func(view buffer.View) {
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
+ PayloadLength: simpleBodySize,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
})
simpleBody(view[header.IPv6MinimumSize:])
}
@@ -1048,11 +1047,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(icmpSize),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(icmpSize),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1108,11 +1107,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
errorICMPBody := func(view buffer.View) {
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
+ PayloadLength: simpleBodySize,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
})
simpleBody(view[header.IPv6MinimumSize:])
}
@@ -1227,11 +1226,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(size + payloadSize),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(size + payloadSize),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
@@ -1381,8 +1380,8 @@ func TestLinkAddressRequest(t *testing.T) {
if !ok {
t.Fatal("expected to send a link address request")
}
- if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr)
+ if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
}
if pkt.Route.RemoteAddress != test.expectedRemoteAddr {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr)
@@ -1445,11 +1444,11 @@ func TestPacketQueing(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1463,8 +1462,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
}
- if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
@@ -1487,11 +1486,11 @@ func TestPacketQueing(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1505,8 +1504,8 @@ func TestPacketQueing(t *testing.T) {
if p.Proto != ProtocolNumber {
t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
}
- if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr)
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
@@ -1556,8 +1555,8 @@ func TestPacketQueing(t *testing.T) {
t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber)
}
snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address)
- if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want {
- t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
@@ -1586,11 +1585,11 @@ func TestPacketQueing(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1828,11 +1827,11 @@ func TestCallsToNeighborCache(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.source,
- DstAddr: test.destination,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.source,
+ DstAddr: test.destination,
})
ep.HandlePacket(pkt)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 8bf84601f..f2018d073 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -19,6 +19,7 @@ import (
"encoding/binary"
"fmt"
"hash/fnv"
+ "math"
"sort"
"sync/atomic"
"time"
@@ -60,6 +61,108 @@ const (
buckets = 2048
)
+// policyTable is the default policy table defined in RFC 6724 section 2.1.
+//
+// A more human-readable version:
+//
+// Prefix Precedence Label
+// ::1/128 50 0
+// ::/0 40 1
+// ::ffff:0:0/96 35 4
+// 2002::/16 30 2
+// 2001::/32 5 5
+// fc00::/7 3 13
+// ::/96 1 3
+// fec0::/10 1 11
+// 3ffe::/16 1 12
+//
+// The table is sorted by prefix length so longest-prefix match can be easily
+// achieved.
+//
+// We willingly left out ::/96, fec0::/10 and 3ffe::/16 since those prefix
+// assignments are deprecated.
+//
+// As per RFC 4291 section 2.5.5.1 (for ::/96),
+//
+// The "IPv4-Compatible IPv6 address" is now deprecated because the
+// current IPv6 transition mechanisms no longer use these addresses.
+// New or updated implementations are not required to support this
+// address type.
+//
+// As per RFC 3879 section 4 (for fec0::/10),
+//
+// This document formally deprecates the IPv6 site-local unicast prefix
+// defined in [RFC3513], i.e., 1111111011 binary or FEC0::/10.
+//
+// As per RFC 3701 section 1 (for 3ffe::/16),
+//
+// As clearly stated in [TEST-NEW], the addresses for the 6bone are
+// temporary and will be reclaimed in the future. It further states
+// that all users of these addresses (within the 3FFE::/16 prefix) will
+// be required to renumber at some time in the future.
+//
+// and section 2,
+//
+// Thus after the pTLA allocation cutoff date January 1, 2004, it is
+// REQUIRED that no new 6bone 3FFE pTLAs be allocated.
+//
+// MUST NOT BE MODIFIED.
+var policyTable = [...]struct {
+ subnet tcpip.Subnet
+
+ label uint8
+}{
+ // ::1/128
+ {
+ subnet: header.IPv6Loopback.WithPrefix().Subnet(),
+ label: 0,
+ },
+ // ::ffff:0:0/96
+ {
+ subnet: header.IPv4MappedIPv6Subnet,
+ label: 4,
+ },
+ // 2001::/32 (Teredo prefix as per RFC 4380 section 2.6).
+ {
+ subnet: tcpip.AddressWithPrefix{
+ Address: "\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 32,
+ }.Subnet(),
+ label: 5,
+ },
+ // 2002::/16 (6to4 prefix as per RFC 3056 section 2).
+ {
+ subnet: tcpip.AddressWithPrefix{
+ Address: "\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 16,
+ }.Subnet(),
+ label: 2,
+ },
+ // fc00::/7 (Unique local addresses as per RFC 4193 section 3.1).
+ {
+ subnet: tcpip.AddressWithPrefix{
+ Address: "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 7,
+ }.Subnet(),
+ label: 13,
+ },
+ // ::/0
+ {
+ subnet: header.IPv6EmptySubnet,
+ label: 1,
+ },
+}
+
+func getLabel(addr tcpip.Address) uint8 {
+ for _, p := range policyTable {
+ if p.subnet.Contains(addr) {
+ return p.label
+ }
+ }
+
+ panic(fmt.Sprintf("should have a label for address = %s", addr))
+}
+
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -85,9 +188,8 @@ type endpoint struct {
addressableEndpointState stack.AddressableEndpointState
ndp ndpState
+ mld mldState
}
-
- mld mldState
}
// NICNameFromID is a function that returns a stable name for the specified NIC,
@@ -122,6 +224,45 @@ type OpaqueInterfaceIdentifierOptions struct {
SecretKey []byte
}
+// onAddressAssignedLocked handles an address being assigned.
+//
+// Precondition: e.mu must be exclusively locked.
+func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
+ // As per RFC 2710 section 3,
+ //
+ // All MLD messages described in this document are sent with a link-local
+ // IPv6 Source Address, ...
+ //
+ // If we just completed DAD for a link-local address, then attempt to send any
+ // queued MLD reports. Note, we may have sent reports already for some of the
+ // groups before we had a valid link-local address to use as the source for
+ // the MLD messages, but that was only so that MLD snooping switches are aware
+ // of our membership to groups - routers would not have handled those reports.
+ //
+ // As per RFC 3590 section 4,
+ //
+ // MLD Report and Done messages are sent with a link-local address as
+ // the IPv6 source address, if a valid address is available on the
+ // interface. If a valid link-local address is not available (e.g., one
+ // has not been configured), the message is sent with the unspecified
+ // address (::) as the IPv6 source address.
+ //
+ // Once a valid link-local address is available, a node SHOULD generate
+ // new MLD Report messages for all multicast addresses joined on the
+ // interface.
+ //
+ // Routers receiving an MLD Report or Done message with the unspecified
+ // address as the IPv6 source address MUST silently discard the packet
+ // without taking any action on the packets contents.
+ //
+ // Snooping switches MUST manage multicast forwarding state based on MLD
+ // Report and Done messages sent with the unspecified address as the
+ // IPv6 source address.
+ if header.IsV6LinkLocalAddress(addr) {
+ e.mu.mld.sendQueuedReports()
+ }
+}
+
// InvalidateDefaultRouter implements stack.NDPEndpoint.
func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.mu.Lock()
@@ -232,7 +373,7 @@ func (e *endpoint) Enable() *tcpip.Error {
// endpoint may have left groups from the perspective of MLD when the
// endpoint was disabled. Either way, we need to let routers know to
// send us multicast traffic.
- e.mld.initializeAll()
+ e.mu.mld.initializeAll()
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
// use IPv6. This is required to ensure that this node properly receives
@@ -334,7 +475,7 @@ func (e *endpoint) Disable() {
}
func (e *endpoint) disableLocked() {
- if !e.setEnabled(false) {
+ if !e.Enabled() {
return
}
@@ -349,7 +490,11 @@ func (e *endpoint) disableLocked() {
// Leave groups from the perspective of MLD so that routers know that
// we are no longer interested in the group.
- e.mld.softLeaveAll()
+ e.mu.mld.softLeaveAll()
+
+ if !e.setEnabled(false) {
+ panic("should have only done work to disable the endpoint if it was enabled")
+ }
}
// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
@@ -389,19 +534,27 @@ func (e *endpoint) MTU() uint32 {
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
+ // TODO(gvisor.dev/issues/5035): The maximum header length returned here does
+ // not open the possibility for the caller to know about size required for
+ // extension headers.
return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
- length := uint16(pkt.Size())
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) {
+ extHdrsLen := extensionHeaders.Length()
+ length := pkt.Size() + extensionHeaders.Length()
+ if length > math.MaxUint16 {
+ panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16))
+ }
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
- PayloadLength: length,
- NextHeader: uint8(params.Protocol),
- HopLimit: params.TTL,
- TrafficClass: params.TOS,
- SrcAddr: srcAddr,
- DstAddr: dstAddr,
+ PayloadLength: uint16(length),
+ TransportProtocol: params.Protocol,
+ HopLimit: params.TTL,
+ TrafficClass: params.TOS,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ ExtensionHeaders: extensionHeaders,
})
pkt.NetworkProtocolNumber = ProtocolNumber
}
@@ -456,7 +609,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */)
// iptables filtering. All packets that reach here are locally
// generated.
@@ -545,7 +698,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */)
networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
if err != nil {
@@ -1177,13 +1330,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
return addressEndpoint, nil
}
- snmc := header.SolicitedNodeAddr(addr.Address)
- if err := e.joinGroupLocked(snmc); err != nil {
- // joinGroupLocked only returns an error if the group address is not a valid
- // IPv6 multicast address.
- panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err))
- }
-
addressEndpoint.SetKind(stack.PermanentTentative)
if e.Enabled() {
@@ -1192,6 +1338,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
}
}
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if err := e.joinGroupLocked(snmc); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err))
+ }
+
return addressEndpoint, nil
}
@@ -1293,6 +1446,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow
return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
}
+// getLinkLocalAddressRLocked returns a link-local address from the primary list
+// of addresses, if one is available.
+//
+// See stack.PrimaryEndpointBehavior for more details about the primary list.
+//
+// Precondition: e.mu must be read locked.
+func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address {
+ var linkLocalAddr tcpip.Address
+ e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
+ if addressEndpoint.IsAssigned(false /* allowExpired */) {
+ if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) {
+ linkLocalAddr = addr
+ return false
+ }
+ }
+ return true
+ })
+ return linkLocalAddr
+}
+
// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
// but with locking requirements.
//
@@ -1302,7 +1475,11 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
// RFC 6724 section 5.
type addrCandidate struct {
addressEndpoint stack.AddressEndpoint
+ addr tcpip.Address
scope header.IPv6AddressScope
+
+ label uint8
+ matchingPrefix uint8
}
if len(remoteAddr) == 0 {
@@ -1312,10 +1489,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
// Create a candidate set of available addresses we can potentially use as a
// source address.
var cs []addrCandidate
- e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
+ e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
// If r is not valid for outgoing connections, it is not a valid endpoint.
if !addressEndpoint.IsAssigned(allowExpired) {
- return
+ return true
}
addr := addressEndpoint.AddressWithPrefix().Address
@@ -1329,8 +1506,13 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
cs = append(cs, addrCandidate{
addressEndpoint: addressEndpoint,
+ addr: addr,
scope: scope,
+ label: getLabel(addr),
+ matchingPrefix: remoteAddr.MatchingPrefix(addr),
})
+
+ return true
})
remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
@@ -1339,18 +1521,20 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
}
+ remoteLabel := getLabel(remoteAddr)
+
// Sort the addresses as per RFC 6724 section 5 rules 1-3.
//
- // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ // TODO(b/146021396): Implement rules 4, 5 of RFC 6724 section 5.
sort.Slice(cs, func(i, j int) bool {
sa := cs[i]
sb := cs[j]
// Prefer same address as per RFC 6724 section 5 rule 1.
- if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ if sa.addr == remoteAddr {
return true
}
- if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ if sb.addr == remoteAddr {
return false
}
@@ -1367,11 +1551,29 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
return sbDep
}
+ // Prefer matching label as per RFC 6724 section 5 rule 6.
+ if sa, sb := sa.label == remoteLabel, sb.label == remoteLabel; sa != sb {
+ if sa {
+ return true
+ }
+ if sb {
+ return false
+ }
+ }
+
// Prefer temporary addresses as per RFC 6724 section 5 rule 7.
if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp {
return saTemp
}
+ // Use longest matching prefix as per RFC 6724 section 5 rule 8.
+ if sa.matchingPrefix > sb.matchingPrefix {
+ return true
+ }
+ if sb.matchingPrefix > sa.matchingPrefix {
+ return false
+ }
+
// sa and sb are equal, return the endpoint that is closest to the front of
// the primary endpoint list.
return i < j
@@ -1417,7 +1619,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadAddress
}
- e.mld.joinGroup(addr)
+ e.mu.mld.joinGroup(addr)
return nil
}
@@ -1432,14 +1634,14 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
//
// Precondition: e.mu must be locked.
func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
- return e.mld.leaveGroup(addr)
+ return e.mu.mld.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mld.isInGroup(addr)
+ return e.mu.mld.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
@@ -1504,17 +1706,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
dispatcher: dispatcher,
protocol: p,
}
+ e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
- e.mu.ndp = ndpState{
- ep: e,
- configs: p.options.NDPConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
- }
- e.mu.ndp.initializeTempAddrState()
- e.mld.init(e, p.options.MLD)
+ e.mu.ndp.init(e)
+ e.mu.mld.init(e)
+ e.mu.Unlock()
p.mu.Lock()
defer p.mu.Unlock()
@@ -1735,24 +1931,25 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea
fragPkt.NetworkProtocolNumber = ProtocolNumber
originalIPHeadersLength := len(originalIPHeaders)
- fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
+
+ s := header.IPv6ExtHdrSerializer{&header.IPv6SerializableFragmentExtHdr{
+ FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
+ M: more,
+ Identification: id,
+ }}
+
+ fragmentIPHeadersLength := originalIPHeadersLength + s.Length()
fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
- fragPkt.NetworkProtocolNumber = ProtocolNumber
// Copy the IPv6 header and any extension headers already populated.
if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength))
}
- fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader)
- fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
- fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:])
- fragmentHeader.Encode(&header.IPv6FragmentFields{
- M: more,
- FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
- Identification: id,
- NextHeader: uint8(transportProto),
- })
+ nextHeader, _ := s.Serialize(transportProto, fragmentIPHeaders[originalIPHeadersLength:])
+
+ fragmentIPHeaders.SetNextHeader(nextHeader)
+ fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
return fragPkt, more
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 1c01f17ab..5f07d3af8 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -69,11 +69,11 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -127,11 +127,11 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -915,10 +915,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
- NextHeader: ipv6NextHdr,
- HopLimit: 255,
- SrcAddr: addr1,
- DstAddr: dstAddr,
+ // We're lying about transport protocol here to be able to generate
+ // raw extension headers from the test definitions.
+ TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr),
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: dstAddr,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -1947,10 +1949,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(f.data.Size()),
- NextHeader: f.nextHdr,
- HopLimit: 255,
- SrcAddr: f.srcAddr,
- DstAddr: f.dstAddr,
+ // We're lying about transport protocol here so that we can generate
+ // raw extension headers for the tests.
+ TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr),
+ HopLimit: 255,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
})
vv := hdr.View().ToVectorisedView()
@@ -1995,7 +1999,7 @@ func TestInvalidIPv6Fragments(t *testing.T) {
type fragmentData struct {
ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6FragmentFields
+ ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
payload []byte
}
@@ -2014,14 +2018,13 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 9,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 9,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0 >> 3,
M: true,
Identification: ident,
@@ -2041,14 +2044,13 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3,
M: false,
Identification: ident,
@@ -2089,10 +2091,9 @@ func TestInvalidIPv6Fragments(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- ip.Encode(&f.ipv6Fields)
-
- fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
- fragHDR.Encode(&f.ipv6FragmentFields)
+ encodeArgs := f.ipv6Fields
+ encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
+ ip.Encode(&encodeArgs)
vv := hdr.View().ToVectorisedView()
vv.AppendView(f.payload)
@@ -2154,7 +2155,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
type fragmentData struct {
ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6FragmentFields
+ ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
payload []byte
}
@@ -2168,14 +2169,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2190,14 +2190,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2206,14 +2205,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2228,14 +2226,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2250,14 +2247,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2266,14 +2262,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2288,14 +2283,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2304,14 +2298,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2350,10 +2343,11 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- ip.Encode(&f.ipv6Fields)
+ encodeArgs := f.ipv6Fields
+ encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
+ ip.Encode(&encodeArgs)
fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
- fragHDR.Encode(&f.ipv6FragmentFields)
vv := hdr.View().ToVectorisedView()
vv.AppendView(f.payload)
@@ -2994,11 +2988,11 @@ func TestForwarding(t *testing.T) {
icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: test.TTL,
- SrcAddr: remoteIPv6Addr1,
- DstAddr: remoteIPv6Addr2,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: test.TTL,
+ SrcAddr: remoteIPv6Addr1,
+ DstAddr: remoteIPv6Addr2,
})
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index 4c06b3f0c..e8d1e7a79 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -40,6 +40,9 @@ type MLDOptions struct {
// When enabled, MLD may transmit MLD report and done messages when
// joining and leaving multicast groups respectively, and handle incoming
// MLD packets.
+ //
+ // This field is ignored and is always assumed to be false for interfaces
+ // without neighbouring nodes (e.g. loopback).
Enabled bool
}
@@ -55,22 +58,35 @@ type mldState struct {
genericMulticastProtocol ip.GenericMulticastProtocolState
}
+// Enabled implements ip.MulticastGroupProtocol.
+func (mld *mldState) Enabled() bool {
+ // No need to perform MLD on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return mld.ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback() && mld.ep.Enabled()
+}
+
// SendReport implements ip.MulticastGroupProtocol.
-func (mld *mldState) SendReport(groupAddress tcpip.Address) *tcpip.Error {
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport)
}
// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: mld.ep.mu must be read locked.
func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
- return mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+ _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+ return err
}
// init sets up an mldState struct, and is required to be called before using
// a new mldState.
-func (mld *mldState) init(ep *endpoint, opts MLDOptions) {
+//
+// Must only be called once for the lifetime of mld.
+func (mld *mldState) init(ep *endpoint) {
mld.ep = ep
- mld.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{
- Enabled: opts.Enabled,
+ mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
Rand: ep.protocol.stack.Rand(),
Clock: ep.protocol.stack.Clock(),
Protocol: mld,
@@ -79,33 +95,45 @@ func (mld *mldState) init(ep *endpoint, opts MLDOptions) {
})
}
+// handleMulticastListenerQuery handles a query message.
+//
+// Precondition: mld.ep.mu must be locked.
func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) {
- mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay())
+ mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay())
}
+// handleMulticastListenerReport handles a report message.
+//
+// Precondition: mld.ep.mu must be locked.
func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
- mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress())
+ mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress())
}
// joinGroup handles joining a new group and sending and scheduling the required
// messages.
//
// If the group is already joined, returns tcpip.ErrDuplicateAddress.
+//
+// Precondition: mld.ep.mu must be locked.
func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
- mld.genericMulticastProtocol.JoinGroup(groupAddress, !mld.ep.Enabled() /* dontInitialize */)
+ mld.genericMulticastProtocol.JoinGroupLocked(groupAddress)
}
// isInGroup returns true if the specified group has been joined locally.
+//
+// Precondition: mld.ep.mu must be read locked.
func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool {
- return mld.genericMulticastProtocol.IsLocallyJoined(groupAddress)
+ return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
}
// leaveGroup handles removing the group from the membership map, cancels any
// delay timers associated with that group, and sends the Done message, if
// required.
+//
+// Precondition: mld.ep.mu must be locked.
func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
// LeaveGroup returns false only if the group was not joined.
- if mld.genericMulticastProtocol.LeaveGroup(groupAddress) {
+ if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
return nil
}
@@ -114,17 +142,31 @@ func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
// softLeaveAll leaves all groups from the perspective of MLD, but remains
// joined locally.
+//
+// Precondition: mld.ep.mu must be locked.
func (mld *mldState) softLeaveAll() {
- mld.genericMulticastProtocol.MakeAllNonMember()
+ mld.genericMulticastProtocol.MakeAllNonMemberLocked()
}
// initializeAll attemps to initialize the MLD state for each group that has
// been joined locally.
+//
+// Precondition: mld.ep.mu must be locked.
func (mld *mldState) initializeAll() {
- mld.genericMulticastProtocol.InitializeGroups()
+ mld.genericMulticastProtocol.InitializeGroupsLocked()
+}
+
+// sendQueuedReports attempts to send any reports that are queued for sending.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) sendQueuedReports() {
+ mld.genericMulticastProtocol.SendQueuedReportsLocked()
}
-func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error {
+// writePacket assembles and sends an MLD packet.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) {
sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
var mldStat *tcpip.StatCounter
switch mldType {
@@ -139,26 +181,82 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp
icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize))
icmp.SetType(mldType)
header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress)
- // TODO(gvisor.dev/issue/4888): We should not use the unspecified address,
- // rather we should select an appropriate local address.
- localAddress := header.IPv6Any
+ // As per RFC 2710 section 3,
+ //
+ // All MLD messages described in this document are sent with a link-local
+ // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert
+ // option in a Hop-by-Hop Options header.
+ //
+ // However, this would cause problems with Duplicate Address Detection with
+ // the first address as MLD snooping switches may not send multicast traffic
+ // that DAD depends on to the node performing DAD without the MLD report, as
+ // documented in RFC 4816:
+ //
+ // Note that when a node joins a multicast address, it typically sends a
+ // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810]
+ // for the multicast address. In the case of Duplicate Address
+ // Detection, the MLD report message is required in order to inform MLD-
+ // snooping switches, rather than routers, to forward multicast packets.
+ // In the above description, the delay for joining the multicast address
+ // thus means delaying transmission of the corresponding MLD report
+ // message. Since the MLD specifications do not request a random delay
+ // to avoid race conditions, just delaying Neighbor Solicitation would
+ // cause congestion by the MLD report messages. The congestion would
+ // then prevent the MLD-snooping switches from working correctly and, as
+ // a result, prevent Duplicate Address Detection from working. The
+ // requirement to include the delay for the MLD report in this case
+ // avoids this scenario. [RFC3590] also talks about some interaction
+ // issues between Duplicate Address Detection and MLD, and specifies
+ // which source address should be used for the MLD report in this case.
+ //
+ // As per RFC 3590 section 4, we should still send out MLD reports with an
+ // unspecified source address if we do not have an assigned link-local
+ // address to use as the source address to ensure DAD works as expected on
+ // networks with MLD snooping switches:
+ //
+ // MLD Report and Done messages are sent with a link-local address as
+ // the IPv6 source address, if a valid address is available on the
+ // interface. If a valid link-local address is not available (e.g., one
+ // has not been configured), the message is sent with the unspecified
+ // address (::) as the IPv6 source address.
+ //
+ // Once a valid link-local address is available, a node SHOULD generate
+ // new MLD Report messages for all multicast addresses joined on the
+ // interface.
+ //
+ // Routers receiving an MLD Report or Done message with the unspecified
+ // address as the IPv6 source address MUST silently discard the packet
+ // without taking any action on the packets contents.
+ //
+ // Snooping switches MUST manage multicast forwarding state based on MLD
+ // Report and Done messages sent with the unspecified address as the
+ // IPv6 source address.
+ localAddress := mld.ep.getLinkLocalAddressRLocked()
+ if len(localAddress) == 0 {
+ localAddress = header.IPv6Any
+ }
+
icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{}))
+ extensionHeaders := header.IPv6ExtHdrSerializer{
+ header.IPv6SerializableHopByHopExtHdr{
+ &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
+ },
+ }
+
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()),
+ ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(),
Data: buffer.View(icmp).ToVectorisedView(),
})
mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.MLDHopLimit,
- })
- // TODO(b/162198658): set the ROUTER_ALERT option when sending Host
- // Membership Reports.
+ }, extensionHeaders)
if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sentStats.Dropped.Increment()
- return err
+ return false, err
}
mldStat.Increment()
- return nil
+ return localAddress != header.IPv6Any, nil
}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index 5677bdd54..e2778b656 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -16,8 +16,12 @@ package ipv6_test
import (
"testing"
+ "time"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -25,9 +29,34 @@ import (
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
)
+var (
+ linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr)
+ globalAddrSNMC = header.SolicitedNodeAddr(globalAddr)
+)
+
+func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) {
+ t.Helper()
+
+ checker.IPv6WithExtHdr(t, p,
+ checker.IPv6ExtHdr(
+ checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
+ ),
+ checker.SrcAddr(localAddress),
+ checker.DstAddr(remoteAddress),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(mldType, header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(0),
+ checker.MLDMulticastAddress(groupAddress),
+ ),
+ )
+}
+
func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
const nicID = 1
@@ -46,45 +75,223 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
// The stack will join an address's solicited node multicast address when
// an address is added. An MLD report message should be sent for the
// solicited-node group.
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
}
- {
- p, ok := e.Read()
- if !ok {
- t.Fatal("expected a report message to be sent")
- }
- snmc := header.SolicitedNodeAddr(addr1)
- checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())),
- checker.DstAddr(snmc),
- // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
- checker.TTL(1),
- checker.MLD(header.ICMPv6MulticastListenerReport, header.MLDMinimumSize,
- checker.MLDMaxRespDelay(0),
- checker.MLDMulticastAddress(snmc),
- ),
- )
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
}
// The stack will leave an address's solicited node multicast address when
// an address is removed. An MLD done message should be sent for the
// solicited-node group.
- if err := s.RemoveAddress(nicID, addr1); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
+ if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a done message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
}
- {
- p, ok := e.Read()
- if !ok {
- t.Fatal("expected a done message to be sent")
- }
- snmc := header.SolicitedNodeAddr(addr1)
- checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(1),
- checker.MLD(header.ICMPv6MulticastListenerDone, header.MLDMinimumSize,
- checker.MLDMaxRespDelay(0),
- checker.MLDMulticastAddress(snmc),
- ),
- )
+}
+
+func TestSendQueuedMLDReports(t *testing.T) {
+ const (
+ nicID = 1
+ maxReports = 2
+ )
+
+ tests := []struct {
+ name string
+ dadTransmits uint8
+ retransmitTimer time.Duration
+ }{
+ {
+ name: "DAD Disabled",
+ dadTransmits: 0,
+ retransmitTimer: 0,
+ },
+ {
+ name: "DAD Enabled",
+ dadTransmits: 1,
+ retransmitTimer: time.Second,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: test.dadTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ },
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ Clock: clock,
+ })
+
+ // Allow space for an extra packet so we can observe packets that were
+ // unexpectedly sent.
+ e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ resolveDAD := func(addr, snmc tcpip.Address) {
+ clock.Advance(dadResolutionTime)
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected DAD packet")
+ } else {
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(addr),
+ checker.NDPNSOptions(nil),
+ ))
+ }
+ }
+
+ var reportCounter uint64
+ reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ var doneCounter uint64
+ doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ if got := doneStat.Value(); got != doneCounter {
+ t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
+ }
+
+ // Joining a group without an assigned address should send an MLD report
+ // with the unspecified address.
+ if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err)
+ }
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", globalMulticastAddr)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Adding a global address should not send reports for the already joined
+ // group since we should only send queued reports when a link-local
+ // addres sis assigned.
+ //
+ // Note, we will still expect to send a report for the global address's
+ // solicited node address from the unspecified address as per RFC 3590
+ // section 4.
+ if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
+ }
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", globalAddrSNMC)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC)
+ }
+ if dadResolutionTime != 0 {
+ // Reports should not be sent when the address resolves.
+ resolveDAD(globalAddr, globalAddrSNMC)
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ }
+ // Leave the group since we don't care about the global address's
+ // solicited node multicast group membership.
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err)
+ }
+ if got := doneStat.Value(); got != doneCounter {
+ t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Adding a link-local address should send a report for its solicited node
+ // address and globalMulticastAddr.
+ if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
+ }
+ if dadResolutionTime != 0 {
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", linkLocalAddrSNMC)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+ resolveDAD(linkLocalAddr, linkLocalAddrSNMC)
+ }
+
+ // We expect two batches of reports to be sent (1 batch when the
+ // link-local address is assigned, and another after the maximum
+ // unsolicited report interval.
+ for i := 0; i < 2; i++ {
+ // We expect reports to be sent (one for globalMulticastAddr and another
+ // for linkLocalAddrSNMC).
+ reportCounter += maxReports
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+
+ addrs := map[tcpip.Address]bool{
+ globalMulticastAddr: false,
+ linkLocalAddrSNMC: false,
+ }
+ for _ = range addrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs)
+ }
+
+ addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress()
+ if seen, ok := addrs[addr]; !ok {
+ t.Fatalf("got unexpected packet destined to %s", addr)
+ } else if seen {
+ t.Fatalf("got another packet destined to %s", addr)
+ }
+
+ addrs[addr] = true
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr)
+
+ clock.Advance(ipv6.UnsolicitedReportIntervalMax)
+ }
+ }
+
+ // Should not send any more reports.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ })
}
}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 8cb7d4dab..d515eb622 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -20,6 +20,7 @@ import (
"math/rand"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() {
// ndpState is the per-interface NDP state.
type ndpState struct {
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
+
// The IPv6 endpoint this ndpState is for.
ep *endpoint
@@ -643,6 +647,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
+ ndp.ep.onAddressAssignedLocked(addr)
return nil
}
@@ -686,12 +691,14 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
}
- // If DAD resolved for a stable SLAAC address, attempt generation of a
- // temporary SLAAC address.
- if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
- // Reset the generation attempts counter as we are starting the generation
- // of a new address for the SLAAC prefix.
- ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
+ if dadDone {
+ if addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
+ // Reset the generation attempts counter as we are starting the
+ // generation of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
+ }
+
+ ndp.ep.onAddressAssignedLocked(addr)
}
}),
}
@@ -728,7 +735,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add
ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- })
+ }, nil /* extensionHeaders */)
if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
@@ -1850,7 +1857,7 @@ func (ndp *ndpState) startSolicitingRouters() {
ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- })
+ }, nil /* extensionHeaders */)
if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
@@ -1884,11 +1891,19 @@ func (ndp *ndpState) stopSolicitingRouters() {
ndp.rtrSolicitJob = nil
}
-// initializeTempAddrState initializes state related to temporary SLAAC
-// addresses.
-func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
+func (ndp *ndpState) init(ep *endpoint) {
+ if ndp.dad != nil {
+ panic("attempted to initialize NDP state twice")
+ }
+ ndp.ep = ep
+ ndp.configs = ep.protocol.options.NDPConfigs
+ ndp.dad = make(map[tcpip.Address]dadState)
+ ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
+ ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
+ ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
+
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 95c626bb8..7ddb19c00 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -213,11 +213,11 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
@@ -319,11 +319,11 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
@@ -599,11 +599,11 @@ func TestNeighorSolicitationResponse(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: test.nsSrc,
- DstAddr: test.nsDst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
})
invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
@@ -650,8 +650,8 @@ func TestNeighorSolicitationResponse(t *testing.T) {
if p.Route.RemoteAddress != respNSDst {
t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst)
}
- if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(respNSDst); got != want {
- t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want)
+ if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
@@ -681,11 +681,11 @@ func TestNeighorSolicitationResponse(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.nsSrc,
- DstAddr: nicAddr,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.nsSrc,
+ DstAddr: nicAddr,
})
e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -706,8 +706,8 @@ func TestNeighorSolicitationResponse(t *testing.T) {
if p.Route.RemoteAddress != test.naDst {
t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst)
}
- if got := p.Route.RemoteLinkAddress(); got != test.naDstLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, test.naDstLinkAddr)
+ if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
@@ -785,11 +785,11 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
@@ -898,11 +898,11 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
@@ -979,29 +979,25 @@ func TestNDPValidation(t *testing.T) {
}
handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
- nextHdr := uint8(header.ICMPv6ProtocolNumber)
- var extensions buffer.View
+ var extHdrs header.IPv6ExtHdrSerializer
if atomicFragment {
- extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
- extensions[0] = nextHdr
- nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{})
}
+ extHdrsLen := extHdrs.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions),
+ ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen,
Data: payload.ToVectorisedView(),
})
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions)))
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload) + len(extensions)),
- NextHeader: nextHdr,
- HopLimit: hopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(payload) + extHdrsLen),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: hopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ ExtensionHeaders: extHdrs,
})
- if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
- t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
- }
ep.HandlePacket(pkt)
}
@@ -1351,11 +1347,11 @@ func TestRouterAdvertValidation(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
stats := s.Stats().ICMP.V6.PacketsReceived
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index 95fb67986..05d98a0a5 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -34,6 +35,9 @@ import (
const (
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ ipv4Addr = tcpip.Address("\x0a\x00\x00\x01")
+ ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+
ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03")
ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04")
ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05")
@@ -48,6 +52,8 @@ const (
mldQuery = uint8(header.ICMPv6MulticastListenerQuery)
mldReport = uint8(header.ICMPv6MulticastListenerReport)
mldDone = uint8(header.ICMPv6MulticastListenerDone)
+
+ maxUnsolicitedReports = 2
)
var (
@@ -61,6 +67,8 @@ var (
}
return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
}()
+
+ ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr)
)
// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet
@@ -69,7 +77,11 @@ func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.A
t.Helper()
payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
- checker.IPv6(t, payload,
+ checker.IPv6WithExtHdr(t, payload,
+ checker.IPv6ExtHdr(
+ checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
+ ),
+ checker.SrcAddr(ipv6Addr),
checker.DstAddr(remoteAddress),
// Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
checker.TTL(1),
@@ -87,6 +99,7 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.
payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
checker.IPv4(t, payload,
+ checker.SrcAddr(ipv4Addr),
checker.DstAddr(remoteAddress),
// TTL for an IGMP message must be 1 as per RFC 2236 section 2.
checker.TTL(1),
@@ -99,23 +112,31 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.
)
}
-func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
+func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
t.Helper()
- // Create an endpoint of queue size 2, since no more than 2 packets are ever
- // queued in the tests in this file.
- e := channel.New(2, 1280, linkAddr)
+ e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr)
+ s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e)
+ return e, s, clock
+}
+
+func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ igmpEnabled := v4 && mgpEnabled
+ mldEnabled := !v4 && mgpEnabled
+
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocolWithOptions(ipv4.Options{
IGMP: ipv4.IGMPOptions{
- Enabled: mgpEnabled,
+ Enabled: igmpEnabled,
},
}),
ipv6.NewProtocolWithOptions(ipv6.Options{
MLD: ipv6.MLDOptions{
- Enabled: mgpEnabled,
+ Enabled: mldEnabled,
},
}),
},
@@ -124,8 +145,59 @@ func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4Addr, err)
+ }
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err)
+ }
- return e, s, clock
+ return s, clock
+}
+
+// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join
+// when it is created with an IPv6 address.
+//
+// To not interfere with tests, checkInitialIPv6Groups will leave the added
+// address's solicited node multicast group so that the tests can all assume
+// the NIC has not joined any IPv6 groups.
+func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) {
+ t.Helper()
+
+ stats := s.Stats().ICMP.V6.PacketsSent
+
+ reportCounter++
+ if got := stats.MulticastListenerReport.Value(); got != reportCounter {
+ t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC)
+ }
+
+ // Leave the group to not affect the tests. This is fine since we are not
+ // testing DAD or the solicited node address specifically.
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err)
+ }
+ leaveCounter++
+ if got := stats.MulticastListenerDone.Value(); got != leaveCounter {
+ t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+
+ return reportCounter, leaveCounter
}
// createAndInjectIGMPPacket creates and injects an IGMP packet with the
@@ -170,11 +242,11 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay b
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(icmpSize),
- HopLimit: header.MLDHopLimit,
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- SrcAddr: header.IPv4Any,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(icmpSize),
+ HopLimit: header.MLDHopLimit,
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
icmp := header.ICMPv6(buf[header.IPv6MinimumSize:])
@@ -232,13 +304,13 @@ func TestMGPDisabled(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, false)
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */)
// This NIC may join multicast groups when it is enabled but since MGP is
// disabled, no reports should be sent.
sentReportStat := test.sentReportStat(s)
if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportState.Value() = %d, want = 0", got)
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
}
clock.Advance(time.Hour)
if p, ok := e.Read(); ok {
@@ -251,7 +323,7 @@ func TestMGPDisabled(t *testing.T) {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
}
if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportState.Value() = %d, want = 0", got)
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
}
clock.Advance(time.Hour)
if p, ok := e.Read(); ok {
@@ -355,7 +427,7 @@ func TestMGPReceiveCounters(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e, s, _ := createStack(t, true)
+ e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */)
test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress)
if got := test.statCounter(s).Value(); got != 1 {
@@ -376,6 +448,7 @@ func TestMGPJoinGroup(t *testing.T) {
sentReportStat func(*stack.Stack) *tcpip.StatCounter
receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
validateReport func(*testing.T, channel.PacketInfo)
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
}{
{
name: "IGMP",
@@ -410,21 +483,28 @@ func TestMGPJoinGroup(t *testing.T) {
validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
},
+ checkInitialGroups: checkInitialIPv6Groups,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, true)
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
+ }
// Test joining a specific address explicitly and verify a Report is sent
// immediately.
if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
}
+ reportCounter++
sentReportStat := test.sentReportStat(s)
- if got := sentReportStat.Value(); got != 1 {
- t.Errorf("got sentReportState.Value() = %d, want = 1", got)
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -442,8 +522,9 @@ func TestMGPJoinGroup(t *testing.T) {
t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt)
}
clock.Advance(test.maxUnsolicitedResponseDelay)
- if got := sentReportStat.Value(); got != 2 {
- t.Errorf("got sentReportState.Value() = %d, want = 2", got)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -464,13 +545,14 @@ func TestMGPJoinGroup(t *testing.T) {
// group the stack sends a leave/done message.
func TestMGPLeaveGroup(t *testing.T) {
tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
- validateReport func(*testing.T, channel.PacketInfo)
- validateLeave func(*testing.T, channel.PacketInfo)
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo)
+ validateLeave func(*testing.T, channel.PacketInfo)
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
}{
{
name: "IGMP",
@@ -513,18 +595,26 @@ func TestMGPLeaveGroup(t *testing.T) {
validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
},
+ checkInitialGroups: checkInitialIPv6Groups,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, true)
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
}
- if got := test.sentReportStat(s).Value(); got != 1 {
- t.Errorf("got sentReportStat(_).Value() = %d, want = 1", got)
+ reportCounter++
+ if got := test.sentReportStat(s).Value(); got != reportCounter {
+ t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -539,8 +629,9 @@ func TestMGPLeaveGroup(t *testing.T) {
if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
}
- if got := test.sentLeaveStat(s).Value(); got != 1 {
- t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 1", got)
+ leaveCounter++
+ if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a leave message to be sent")
@@ -570,6 +661,7 @@ func TestMGPQueryMessages(t *testing.T) {
rxQuery func(*channel.Endpoint, uint8, tcpip.Address)
validateReport func(*testing.T, channel.PacketInfo)
maxRespTimeToDuration func(uint8) time.Duration
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
}{
{
name: "IGMP",
@@ -614,6 +706,7 @@ func TestMGPQueryMessages(t *testing.T) {
maxRespTimeToDuration: func(d uint8) time.Duration {
return time.Duration(d) * time.Millisecond
},
+ checkInitialGroups: checkInitialIPv6Groups,
},
}
@@ -647,16 +740,22 @@ func TestMGPQueryMessages(t *testing.T) {
for _, subTest := range subTests {
t.Run(subTest.name, func(t *testing.T) {
- e, s, clock := createStack(t, true)
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
+ }
if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
}
sentReportStat := test.sentReportStat(s)
- for i := uint64(1); i <= 2; i++ {
+ for i := 0; i < maxUnsolicitedReports; i++ {
sentReportStat := test.sentReportStat(s)
- if got := sentReportStat.Value(); got != i {
- t.Errorf("(i=%d) got sentReportState.Value() = %d, want = %d", i, got, i)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatalf("expected %d-th report message to be sent", i)
@@ -686,8 +785,9 @@ func TestMGPQueryMessages(t *testing.T) {
if subTest.expectReport {
clock.Advance(test.maxRespTimeToDuration(maxRespTime))
- if got := sentReportStat.Value(); got != 3 {
- t.Errorf("got sentReportState.Value() = %d, want = 3", got)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -719,6 +819,7 @@ func TestMGPReportMessages(t *testing.T) {
rxReport func(*channel.Endpoint)
validateReport func(*testing.T, channel.PacketInfo)
maxRespTimeToDuration func(uint8) time.Duration
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
}{
{
name: "IGMP",
@@ -761,19 +862,27 @@ func TestMGPReportMessages(t *testing.T) {
maxRespTimeToDuration: func(d uint8) time.Duration {
return time.Duration(d) * time.Millisecond
},
+ checkInitialGroups: checkInitialIPv6Groups,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, true)
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
}
sentReportStat := test.sentReportStat(s)
- if got := sentReportStat.Value(); got != 1 {
- t.Errorf("got sentReportStat.Value() = %d, want = 1", got)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -788,8 +897,8 @@ func TestMGPReportMessages(t *testing.T) {
// reports.
test.rxReport(e)
clock.Advance(time.Hour)
- if got := sentReportStat.Value(); got != 1 {
- t.Errorf("got sentReportStat.Value() = %d, want = 1", got)
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); ok {
t.Errorf("sent unexpected packet = %#v", p)
@@ -804,8 +913,8 @@ func TestMGPReportMessages(t *testing.T) {
t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
}
clock.Advance(time.Hour)
- if got := test.sentLeaveStat(s).Value(); got != 0 {
- t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 0", got)
+ if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
}
// Should not send any more packets.
@@ -829,6 +938,7 @@ func TestMGPWithNICLifecycle(t *testing.T) {
validateReport func(*testing.T, channel.PacketInfo, tcpip.Address)
validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address)
getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
}{
{
name: "IGMP",
@@ -897,10 +1007,31 @@ func TestMGPWithNICLifecycle(t *testing.T) {
t.Helper()
ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
- if got := tcpip.TransportProtocolNumber(ipv6.NextHeader()); got != header.ICMPv6ProtocolNumber {
+
+ ipv6HeaderIter := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
+ buffer.View(ipv6.Payload()).ToVectorisedView(),
+ )
+
+ var transport header.IPv6RawPayloadHeader
+ for {
+ h, done, err := ipv6HeaderIter.Next()
+ if err != nil {
+ t.Fatalf("ipv6HeaderIter.Next(): %s", err)
+ }
+ if done {
+ t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done)
+ }
+ if t, ok := h.(header.IPv6RawPayloadHeader); ok {
+ transport = t
+ break
+ }
+ }
+
+ if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber {
t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber)
}
- icmpv6 := header.ICMPv6(ipv6.Payload())
+ icmpv6 := header.ICMPv6(transport.Buf.ToView())
if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone {
t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone)
}
@@ -914,17 +1045,22 @@ func TestMGPWithNICLifecycle(t *testing.T) {
}
seen[addr] = true
return addr
-
},
+ checkInitialGroups: checkInitialIPv6Groups,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, true)
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
- sentReportStat := test.sentReportStat(s)
var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ sentReportStat := test.sentReportStat(s)
for _, a := range test.multicastAddrs {
if err := s.JoinGroup(test.protoNum, nicID, a); err != nil {
t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err)
@@ -949,7 +1085,7 @@ func TestMGPWithNICLifecycle(t *testing.T) {
t.Fatalf("DisableNIC(%d): %s", nicID, err)
}
sentLeaveStat := test.sentLeaveStat(s)
- leaveCounter := uint64(len(test.multicastAddrs))
+ leaveCounter += uint64(len(test.multicastAddrs))
if got := sentLeaveStat.Value(); got != leaveCounter {
t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
}
@@ -1051,7 +1187,7 @@ func TestMGPWithNICLifecycle(t *testing.T) {
clock.Advance(test.maxUnsolicitedResponseDelay)
reportCounter++
if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportState.Value() = %d, want = %d", got, reportCounter)
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -1067,3 +1203,59 @@ func TestMGPWithNICLifecycle(t *testing.T) {
})
}
}
+
+// TestMGPDisabledOnLoopback tests that the multicast group protocol is not
+// performed on loopback interfaces since they have no neighbours.
+func TestMGPDisabledOnLoopback(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New())
+
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+
+ // Test joining a specific group explicitly and verify that no reports are
+ // sent.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index c53698a6a..f3ad40fdf 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -16,6 +16,8 @@ package tcpip
import (
"sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// SocketOptionsHandler holds methods that help define endpoint specific
@@ -37,6 +39,15 @@ type SocketOptionsHandler interface {
// OnCorkOptionSet is invoked when TCP_CORK is set for an endpoint.
OnCorkOptionSet(v bool)
+
+ // LastError is invoked when SO_ERROR is read for an endpoint.
+ LastError() *Error
+
+ // UpdateLastError updates the endpoint specific last error field.
+ UpdateLastError(err *Error)
+
+ // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE.
+ HasNIC(v int32) bool
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -60,6 +71,19 @@ func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {}
// OnCorkOptionSet implements SocketOptionsHandler.OnCorkOptionSet.
func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {}
+// LastError implements SocketOptionsHandler.LastError.
+func (*DefaultSocketOptionsHandler) LastError() *Error {
+ return nil
+}
+
+// UpdateLastError implements SocketOptionsHandler.UpdateLastError.
+func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {}
+
+// HasNIC implements SocketOptionsHandler.HasNIC.
+func (*DefaultSocketOptionsHandler) HasNIC(int32) bool {
+ return false
+}
+
// SocketOptions contains all the variables which store values for SOL_SOCKET,
// SOL_IP, SOL_IPV6 and SOL_TCP level options.
//
@@ -69,24 +93,24 @@ type SocketOptions struct {
// These fields are accessed and modified using atomic operations.
- // broadcastEnabled determines whether datagram sockets are allowed to send
- // packets to a broadcast address.
+ // broadcastEnabled determines whether datagram sockets are allowed to
+ // send packets to a broadcast address.
broadcastEnabled uint32
- // passCredEnabled determines whether SCM_CREDENTIALS socket control messages
- // are enabled.
+ // passCredEnabled determines whether SCM_CREDENTIALS socket control
+ // messages are enabled.
passCredEnabled uint32
// noChecksumEnabled determines whether UDP checksum is disabled while
// transmitting for this socket.
noChecksumEnabled uint32
- // reuseAddressEnabled determines whether Bind() should allow reuse of local
- // address.
+ // reuseAddressEnabled determines whether Bind() should allow reuse of
+ // local address.
reuseAddressEnabled uint32
- // reusePortEnabled determines whether to permit multiple sockets to be bound
- // to an identical socket address.
+ // reusePortEnabled determines whether to permit multiple sockets to be
+ // bound to an identical socket address.
reusePortEnabled uint32
// keepAliveEnabled determines whether TCP keepalive is enabled for this
@@ -94,7 +118,7 @@ type SocketOptions struct {
keepAliveEnabled uint32
// multicastLoopEnabled determines whether multicast packets sent over a
- // non-loopback interface will be looped back. Analogous to inet->mc_loop.
+ // non-loopback interface will be looped back.
multicastLoopEnabled uint32
// receiveTOSEnabled is used to specify if the TOS ancillary message is
@@ -130,6 +154,28 @@ type SocketOptions struct {
// corkOptionEnabled is used to specify if data should be held until segments
// are full by the TCP transport protocol.
corkOptionEnabled uint32
+
+ // receiveOriginalDstAddress is used to specify if the original destination of
+ // the incoming packet should be returned as an ancillary message.
+ receiveOriginalDstAddress uint32
+
+ // recvErrEnabled determines whether extended reliable error message passing
+ // is enabled.
+ recvErrEnabled uint32
+
+ // errQueue is the per-socket error queue. It is protected by errQueueMu.
+ errQueueMu sync.Mutex `state:"nosave"`
+ errQueue sockErrorList
+
+ // bindToDevice determines the device to which the socket is bound.
+ bindToDevice int32
+
+ // mu protects the access to the below fields.
+ mu sync.Mutex `state:"nosave"`
+
+ // linger determines the amount of time the socket should linger before
+ // close. We currently implement this option for TCP socket only.
+ linger LingerOption
}
// InitHandler initializes the handler. This must be called before using the
@@ -146,6 +192,11 @@ func storeAtomicBool(addr *uint32, v bool) {
atomic.StoreUint32(addr, val)
}
+// SetLastError sets the last error for a socket.
+func (so *SocketOptions) SetLastError(err *Error) {
+ so.handler.UpdateLastError(err)
+}
+
// GetBroadcast gets value for SO_BROADCAST option.
func (so *SocketOptions) GetBroadcast() bool {
return atomic.LoadUint32(&so.broadcastEnabled) != 0
@@ -302,3 +353,168 @@ func (so *SocketOptions) SetCorkOption(v bool) {
storeAtomicBool(&so.corkOptionEnabled, v)
so.handler.OnCorkOptionSet(v)
}
+
+// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option.
+func (so *SocketOptions) GetReceiveOriginalDstAddress() bool {
+ return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0
+}
+
+// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option.
+func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) {
+ storeAtomicBool(&so.receiveOriginalDstAddress, v)
+}
+
+// GetRecvError gets value for IP*_RECVERR option.
+func (so *SocketOptions) GetRecvError() bool {
+ return atomic.LoadUint32(&so.recvErrEnabled) != 0
+}
+
+// SetRecvError sets value for IP*_RECVERR option.
+func (so *SocketOptions) SetRecvError(v bool) {
+ storeAtomicBool(&so.recvErrEnabled, v)
+ if !v {
+ so.pruneErrQueue()
+ }
+}
+
+// GetLastError gets value for SO_ERROR option.
+func (so *SocketOptions) GetLastError() *Error {
+ return so.handler.LastError()
+}
+
+// GetOutOfBandInline gets value for SO_OOBINLINE option.
+func (*SocketOptions) GetOutOfBandInline() bool {
+ return true
+}
+
+// SetOutOfBandInline sets value for SO_OOBINLINE option. We currently do not
+// support disabling this option.
+func (*SocketOptions) SetOutOfBandInline(bool) {}
+
+// GetLinger gets value for SO_LINGER option.
+func (so *SocketOptions) GetLinger() LingerOption {
+ so.mu.Lock()
+ linger := so.linger
+ so.mu.Unlock()
+ return linger
+}
+
+// SetLinger sets value for SO_LINGER option.
+func (so *SocketOptions) SetLinger(linger LingerOption) {
+ so.mu.Lock()
+ so.linger = linger
+ so.mu.Unlock()
+}
+
+// SockErrOrigin represents the constants for error origin.
+type SockErrOrigin uint8
+
+const (
+ // SockExtErrorOriginNone represents an unknown error origin.
+ SockExtErrorOriginNone SockErrOrigin = iota
+
+ // SockExtErrorOriginLocal indicates a local error.
+ SockExtErrorOriginLocal
+
+ // SockExtErrorOriginICMP indicates an IPv4 ICMP error.
+ SockExtErrorOriginICMP
+
+ // SockExtErrorOriginICMP6 indicates an IPv6 ICMP error.
+ SockExtErrorOriginICMP6
+)
+
+// IsICMPErr indicates if the error originated from an ICMP error.
+func (origin SockErrOrigin) IsICMPErr() bool {
+ return origin == SockExtErrorOriginICMP || origin == SockExtErrorOriginICMP6
+}
+
+// SockError represents a queue entry in the per-socket error queue.
+//
+// +stateify savable
+type SockError struct {
+ sockErrorEntry
+
+ // Err is the error caused by the errant packet.
+ Err *Error
+ // ErrOrigin indicates the error origin.
+ ErrOrigin SockErrOrigin
+ // ErrType is the type in the ICMP header.
+ ErrType uint8
+ // ErrCode is the code in the ICMP header.
+ ErrCode uint8
+ // ErrInfo is additional info about the error.
+ ErrInfo uint32
+
+ // Payload is the errant packet's payload.
+ Payload []byte
+ // Dst is the original destination address of the errant packet.
+ Dst FullAddress
+ // Offender is the original sender address of the errant packet.
+ Offender FullAddress
+ // NetProto is the network protocol being used to transmit the packet.
+ NetProto NetworkProtocolNumber
+}
+
+// pruneErrQueue resets the queue.
+func (so *SocketOptions) pruneErrQueue() {
+ so.errQueueMu.Lock()
+ so.errQueue.Reset()
+ so.errQueueMu.Unlock()
+}
+
+// DequeueErr dequeues a socket extended error from the error queue and returns
+// it. Returns nil if queue is empty.
+func (so *SocketOptions) DequeueErr() *SockError {
+ so.errQueueMu.Lock()
+ defer so.errQueueMu.Unlock()
+
+ err := so.errQueue.Front()
+ if err != nil {
+ so.errQueue.Remove(err)
+ }
+ return err
+}
+
+// PeekErr returns the error in the front of the error queue. Returns nil if
+// the error queue is empty.
+func (so *SocketOptions) PeekErr() *SockError {
+ so.errQueueMu.Lock()
+ defer so.errQueueMu.Unlock()
+ return so.errQueue.Front()
+}
+
+// QueueErr inserts the error at the back of the error queue.
+//
+// Preconditions: so.GetRecvError() == true.
+func (so *SocketOptions) QueueErr(err *SockError) {
+ so.errQueueMu.Lock()
+ defer so.errQueueMu.Unlock()
+ so.errQueue.PushBack(err)
+}
+
+// QueueLocalErr queues a local error onto the local queue.
+func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) {
+ so.QueueErr(&SockError{
+ Err: err,
+ ErrOrigin: SockExtErrorOriginLocal,
+ ErrInfo: info,
+ Payload: payload,
+ Dst: dst,
+ NetProto: net,
+ })
+}
+
+// GetBindToDevice gets value for SO_BINDTODEVICE option.
+func (so *SocketOptions) GetBindToDevice() int32 {
+ return atomic.LoadInt32(&so.bindToDevice)
+}
+
+// SetBindToDevice sets value for SO_BINDTODEVICE option.
+func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error {
+ if !so.handler.HasNIC(bindToDevice) {
+ return ErrUnknownDevice
+ }
+
+ atomic.StoreInt32(&so.bindToDevice, bindToDevice)
+ return nil
+}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 9cc6074da..bb30556cf 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -148,7 +148,6 @@ go_test(
],
library = ":stack",
deps = [
- "//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 6e4f5fa46..cd423bf71 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -82,12 +82,16 @@ func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool)
}
// ForEachPrimaryEndpoint calls f for each primary address.
-func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
+//
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) {
a.mu.RLock()
defer a.mu.RUnlock()
for _, ep := range a.mu.primary {
- f(ep)
+ if !f(ep) {
+ return
+ }
}
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 5ec9b3411..93e8e1c51 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -560,6 +560,38 @@ func TestForwardingWithNoResolver(t *testing.T) {
}
}
+func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) {
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 50 * time.Millisecond,
+ onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) {
+ // Don't resolve the link address.
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */)
+
+ const numPackets int = 5
+ // These packets will all be enqueued in the packet queue to wait for link
+ // address resolution.
+ for i := 0; i < numPackets; i++ {
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
+
+ // All packets should fail resolution.
+ // TODO(gvisor.dev/issue/5141): Use a fake clock.
+ for i := 0; i < numPackets; i++ {
+ select {
+ case got := <-ep2.C:
+ t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got)
+ case <-time.After(100 * time.Millisecond):
+ }
+ }
+}
+
func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
tests := []struct {
name string
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index c9b13cd0e..792f4f170 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -18,7 +18,6 @@ import (
"fmt"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -58,9 +57,6 @@ const (
incomplete entryState = iota
// ready means that the address has been resolved and can be used.
ready
- // failed means that address resolution timed out and the address
- // could not be resolved.
- failed
)
// String implements Stringer.
@@ -70,8 +66,6 @@ func (s entryState) String() string {
return "incomplete"
case ready:
return "ready"
- case failed:
- return "failed"
default:
return fmt.Sprintf("unknown(%d)", s)
}
@@ -80,40 +74,48 @@ func (s entryState) String() string {
// A linkAddrEntry is an entry in the linkAddrCache.
// This struct is thread-compatible.
type linkAddrEntry struct {
+ // linkAddrEntryEntry access is synchronized by the linkAddrCache lock.
linkAddrEntryEntry
+ // TODO(gvisor.dev/issue/5150): move these fields under mu.
+ // mu protects the fields below.
+ mu sync.RWMutex
+
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
- // wakers is a set of waiters for address resolution result. Anytime
- // state transitions out of incomplete these waiters are notified.
- wakers map[*sleep.Waker]struct{}
-
- // done is used to allow callers to wait on address resolution. It is nil iff
- // s is incomplete and resolution is not yet in progress.
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
done chan struct{}
+
+ // onResolve is called with the result of address resolution.
+ onResolve []func(tcpip.LinkAddress, bool)
}
-// changeState sets the entry's state to ns, notifying any waiters.
+func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
+ for _, callback := range e.onResolve {
+ callback(linkAddr, len(linkAddr) != 0)
+ }
+ e.onResolve = nil
+ if ch := e.done; ch != nil {
+ close(ch)
+ e.done = nil
+ }
+}
+
+// changeStateLocked sets the entry's state to ns.
//
// The entry's expiration is bumped up to the greater of itself and the passed
// expiration; the zero value indicates immediate expiration, and is set
// unconditionally - this is an implementation detail that allows for entries
// to be reused.
-func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
- // Notify whoever is waiting on address resolution when transitioning
- // out of incomplete.
- if e.s == incomplete && ns != incomplete {
- for w := range e.wakers {
- w.Assert()
- }
- e.wakers = nil
- if ch := e.done; ch != nil {
- close(ch)
- }
- e.done = nil
+//
+// Precondition: e.mu must be locked
+func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) {
+ if e.s == incomplete && ns == ready {
+ e.notifyCompletionLocked(e.linkAddr)
}
if expiration.IsZero() || expiration.After(e.expiration) {
@@ -122,10 +124,6 @@ func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
e.s = ns
}
-func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
- delete(e.wakers, w)
-}
-
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
// Calculate expiration time before acquiring the lock, since expiration is
@@ -135,10 +133,12 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
c.cache.Lock()
entry := c.getOrCreateEntryLocked(k)
- entry.linkAddr = v
-
- entry.changeState(ready, expiration)
c.cache.Unlock()
+
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+ entry.linkAddr = v
+ entry.changeStateLocked(ready, expiration)
}
// getOrCreateEntryLocked retrieves a cache entry associated with k. The
@@ -159,13 +159,14 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
var entry *linkAddrEntry
if len(c.cache.table) == linkAddrCacheSize {
entry = c.cache.lru.Back()
+ entry.mu.Lock()
delete(c.cache.table, entry.addr)
c.cache.lru.Remove(entry)
- // Wake waiters and mark the soon-to-be-reused entry as expired. Note
- // that the state passed doesn't matter when the zero time is passed.
- entry.changeState(failed, time.Time{})
+ // Wake waiters and mark the soon-to-be-reused entry as expired.
+ entry.notifyCompletionLocked("" /* linkAddr */)
+ entry.mu.Unlock()
} else {
entry = new(linkAddrEntry)
}
@@ -180,9 +181,12 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
if linkRes != nil {
if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
+ if onResolve != nil {
+ onResolve(addr, true)
+ }
return addr, nil, nil
}
}
@@ -190,56 +194,35 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
c.cache.Lock()
defer c.cache.Unlock()
entry := c.getOrCreateEntryLocked(k)
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
switch s := entry.s; s {
- case ready, failed:
+ case ready:
if !time.Now().After(entry.expiration) {
// Not expired.
- switch s {
- case ready:
- return entry.linkAddr, nil, nil
- case failed:
- return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ if onResolve != nil {
+ onResolve(entry.linkAddr, true)
}
+ return entry.linkAddr, nil, nil
}
- entry.changeState(incomplete, time.Time{})
+ entry.changeStateLocked(incomplete, time.Time{})
fallthrough
case incomplete:
- if waker != nil {
- if entry.wakers == nil {
- entry.wakers = make(map[*sleep.Waker]struct{})
- }
- entry.wakers[waker] = struct{}{}
+ if onResolve != nil {
+ entry.onResolve = append(entry.onResolve, onResolve)
}
-
if entry.done == nil {
- // Address resolution needs to be initiated.
- if linkRes == nil {
- return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
- }
-
entry.done = make(chan struct{})
go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
-
return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
}
-// removeWaker removes a waker previously added through get().
-func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
- c.cache.Lock()
- defer c.cache.Unlock()
-
- if entry, ok := c.cache.table[k]; ok {
- entry.removeWaker(waker)
- }
-}
-
func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
@@ -257,9 +240,9 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
}
}
-// checkLinkRequest checks whether previous attempt to resolve address has succeeded
-// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
-// can stop, false if another request should be sent.
+// checkLinkRequest checks whether previous attempt to resolve address has
+// succeeded and mark the entry accordingly. Returns true if request can stop,
+// false if another request should be sent.
func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
c.cache.Lock()
defer c.cache.Unlock()
@@ -268,16 +251,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att
// Entry was evicted from the cache.
return true
}
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
switch s := entry.s; s {
- case ready, failed:
- // Entry was made ready by resolver or failed. Either way we're done.
+ case ready:
+ // Entry was made ready by resolver.
case incomplete:
if attempt+1 < c.resolutionAttempts {
// No response yet, need to send another ARP request.
return false
}
- // Max number of retries reached, mark entry as failed.
- entry.changeState(failed, now.Add(c.ageLimit))
+ // Max number of retries reached, delete entry.
+ entry.notifyCompletionLocked("" /* linkAddr */)
+ delete(c.cache.table, k)
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index d2e37f38d..6883045b5 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -21,7 +21,6 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -50,6 +49,7 @@ type testLinkAddressResolver struct {
}
func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+ // TODO(gvisor.dev/issue/5141): Use a fake clock.
time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) })
if f := r.onLinkAddressRequest; f != nil {
f()
@@ -78,16 +78,18 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
}
func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- s.AddWaker(&w, 123)
- defer s.Done()
-
+ var attemptedResolution bool
for {
- if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- return got, err
+ got, ch, err := c.get(addr, linkRes, "", nil, nil)
+ if err == tcpip.ErrWouldBlock {
+ if attemptedResolution {
+ return got, tcpip.ErrNoLinkAddress
+ }
+ attemptedResolution = true
+ <-ch
+ continue
}
- s.Fetch(true)
+ return got, err
}
}
@@ -116,16 +118,19 @@ func TestCacheOverflow(t *testing.T) {
}
}
// The earliest entries should no longer be in the cache.
+ c.cache.Lock()
+ defer c.cache.Unlock()
for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
e := testAddrs[i]
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
+ if entry, ok := c.cache.table[e.addr]; ok {
+ t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
}
}
}
func TestCacheConcurrent(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ linkRes := &testLinkAddressResolver{cache: c}
var wg sync.WaitGroup
for r := 0; r < 16; r++ {
@@ -133,7 +138,6 @@ func TestCacheConcurrent(t *testing.T) {
go func() {
for _, e := range testAddrs {
c.add(e.addr, e.linkAddr)
- c.get(e.addr, nil, "", nil, nil) // make work for gotsan
}
wg.Done()
}()
@@ -144,7 +148,7 @@ func TestCacheConcurrent(t *testing.T) {
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
e := testAddrs[len(testAddrs)-1]
- got, _, err := c.get(e.addr, nil, "", nil, nil)
+ got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
@@ -153,18 +157,22 @@ func TestCacheConcurrent(t *testing.T) {
}
e = testAddrs[0]
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ if entry, ok := c.cache.table[e.addr]; ok {
+ t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
}
}
func TestCacheAgeLimit(t *testing.T) {
c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
+ linkRes := &testLinkAddressResolver{cache: c}
+
e := testAddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err)
}
}
@@ -282,71 +290,3 @@ func TestStaticResolution(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
}
}
-
-// TestCacheWaker verifies that RemoveWaker removes a waker previously added
-// through get().
-func TestCacheWaker(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
-
- // First, sanity check that wakers are working.
- {
- linkRes := &testLinkAddressResolver{cache: c}
- s := sleep.Sleeper{}
- defer s.Done()
-
- const wakerID = 1
- w := sleep.Waker{}
- s.AddWaker(&w, wakerID)
-
- e := testAddrs[0]
-
- if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
- }
- id, ok := s.Fetch(true /* block */)
- if !ok {
- t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)")
- }
- if id != wakerID {
- t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID)
- }
-
- if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil {
- t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
- } else if got != e.linkAddr {
- t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
- }
- }
-
- // Check that RemoveWaker works.
- {
- linkRes := &testLinkAddressResolver{cache: c}
- s := sleep.Sleeper{}
- defer s.Done()
-
- const wakerID = 2 // different than the ID used in the sanity check
- w := sleep.Waker{}
- s.AddWaker(&w, wakerID)
-
- e := testAddrs[1]
- linkRes.onLinkAddressRequest = func() {
- // Remove the waker before the linkAddrCache has the opportunity to send
- // a notification.
- c.removeWaker(e.addr, &w)
- }
-
- if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
- }
-
- if got, err := getBlocking(c, e.addr, linkRes); err != nil {
- t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
- } else if got != e.linkAddr {
- t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
- }
-
- if id, ok := s.Fetch(false /* block */); ok {
- t.Fatalf("unexpected notification from waker with id %d", id)
- }
- }
-}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 31b67b987..61636cae5 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -540,8 +540,8 @@ func TestDADResolve(t *testing.T) {
// Make sure the right remote link address is used.
snmc := header.SolicitedNodeAddr(addr1)
- if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want {
- t.Errorf("got remote link address = %s, want = %s", got, want)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
}
// Check NDP NS packet.
@@ -577,11 +577,11 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: header.IPv6Any,
- DstAddr: snmc,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
}
@@ -623,11 +623,11 @@ func TestDADFail(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: tgt,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: tgt,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
},
@@ -1011,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
iph.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: ip,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: ip,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
return stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -5197,8 +5197,8 @@ func TestRouterSolicitation(t *testing.T) {
}
// Make sure the right remote link address is used.
- if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); got != want {
- t.Errorf("got remote link address = %s, want = %s", got, want)
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 317f6871d..c15f10e76 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -17,7 +17,6 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -99,9 +98,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
n.dynamic.lru.Remove(e)
n.dynamic.count--
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Unknown)
- e.notifyWakersLocked()
+ e.removeLocked()
e.mu.Unlock()
}
n.cache[remoteAddr] = entry
@@ -110,21 +107,27 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
return entry
}
-// entry looks up the neighbor cache for translating address to link address
-// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there
-// is a LinkAddressResolver registered with the network protocol, the cache
-// attempts to resolve the address and returns ErrWouldBlock. If a Waker is
-// provided, it will be notified when address resolution is complete (success
-// or not).
+// entry looks up neighbor information matching the remote address, and returns
+// it if readily available.
+//
+// Returns ErrWouldBlock if the link address is not readily available, along
+// with a notification channel for the caller to block on. Triggers address
+// resolution asynchronously.
+//
+// If onResolve is provided, it will be called either immediately, if resolution
+// is not required, or when address resolution is complete, with the resolved
+// link address and whether resolution succeeded. After any callbacks have been
+// called, the returned notification channel is closed.
+//
+// NB: if a callback is provided, it should not call into the neighbor cache.
//
// If specified, the local address must be an address local to the interface the
// neighbor cache belongs to. The local address is the source address of a
// packet prompting NUD/link address resolution.
//
-// If address resolution is required, ErrNoLinkAddress and a notification
-// channel is returned for the top level caller to block. Channel is closed
-// once address resolution is complete (success or not).
-func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry.
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/5149): Handle static resolution in route.Resolve.
if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
e := NeighborEntry{
Addr: remoteAddr,
@@ -132,6 +135,9 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
State: Static,
UpdatedAtNanos: 0,
}
+ if onResolve != nil {
+ onResolve(linkAddr, true)
+ }
return e, nil, nil
}
@@ -149,37 +155,25 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
// of packets to a neighbor. While reasserting a neighbor's reachability,
// a node continues sending packets to that neighbor using the cached
// link-layer address."
+ if onResolve != nil {
+ onResolve(entry.neigh.LinkAddr, true)
+ }
return entry.neigh, nil, nil
- case Unknown, Incomplete:
- entry.addWakerLocked(w)
-
+ case Unknown, Incomplete, Failed:
+ if onResolve != nil {
+ entry.onResolve = append(entry.onResolve, onResolve)
+ }
if entry.done == nil {
// Address resolution needs to be initiated.
- if linkRes == nil {
- return entry.neigh, nil, tcpip.ErrNoLinkAddress
- }
entry.done = make(chan struct{})
}
-
entry.handlePacketQueuedLocked(localAddr)
return entry.neigh, entry.done, tcpip.ErrWouldBlock
- case Failed:
- return entry.neigh, nil, tcpip.ErrNoLinkAddress
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
}
-// removeWaker removes a waker that has been added when link resolution for
-// addr was requested.
-func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
- n.mu.Lock()
- if entry, ok := n.cache[addr]; ok {
- delete(entry.wakers, waker)
- }
- n.mu.Unlock()
-}
-
// entries returns all entries in the neighbor cache.
func (n *neighborCache) entries() []NeighborEntry {
n.mu.RLock()
@@ -222,34 +216,13 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
return
}
- // Notify that resolution has been interrupted, just in case the entry was
- // in the Incomplete or Probe state.
- entry.dispatchRemoveEventLocked()
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
+ entry.removeLocked()
entry.mu.Unlock()
}
n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
}
-// removeEntryLocked removes the specified entry from the neighbor cache.
-//
-// Prerequisite: n.mu and entry.mu MUST be locked.
-func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
- if entry.neigh.State != Static {
- n.dynamic.lru.Remove(entry)
- n.dynamic.count--
- }
- if entry.neigh.State != Failed {
- entry.dispatchRemoveEventLocked()
- }
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
-
- delete(n.cache, entry.neigh.Addr)
-}
-
// removeEntry removes a dynamic or static entry by address from the neighbor
// cache. Returns true if the entry was found and deleted.
func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
@@ -264,7 +237,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
entry.mu.Lock()
defer entry.mu.Unlock()
- n.removeEntryLocked(entry)
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ }
+
+ entry.removeLocked()
+ delete(n.cache, entry.neigh.Addr)
return true
}
@@ -275,9 +254,7 @@ func (n *neighborCache) clear() {
for _, entry := range n.cache {
entry.mu.Lock()
- entry.dispatchRemoveEventLocked()
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
+ entry.removeLocked()
entry.mu.Unlock()
}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index 732a299f7..a2ed6ae2a 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -28,7 +28,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
)
@@ -190,15 +189,18 @@ type testNeighborResolver struct {
entries *testEntryStore
delay time.Duration
onLinkAddressRequest func()
+ dropReplies bool
}
var _ LinkAddressResolver = (*testNeighborResolver)(nil)
func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
- // Delay handling the request to emulate network latency.
- r.clock.AfterFunc(r.delay, func() {
- r.fakeRequest(targetAddr)
- })
+ if !r.dropReplies {
+ // Delay handling the request to emulate network latency.
+ r.clock.AfterFunc(r.delay, func() {
+ r.fakeRequest(targetAddr)
+ })
+ }
// Execute post address resolution action, if available.
if f := r.onLinkAddressRequest; f != nil {
@@ -291,10 +293,10 @@ func TestNeighborCacheEntry(t *testing.T) {
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -327,7 +329,7 @@ func TestNeighborCacheEntry(t *testing.T) {
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
// No more events should have been dispatched.
@@ -354,11 +356,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -413,7 +415,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
}
@@ -461,7 +463,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
@@ -513,7 +515,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
}
// Expect to find only the most recent entries. The order of entries reported
- // by entries() is undeterministic, so entries have to be sorted before
+ // by entries() is nondeterministic, so entries have to be sorted before
// comparison.
wantUnsortedEntries := opts.wantStaticEntries
for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ {
@@ -575,10 +577,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
@@ -650,7 +652,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
// Add a static entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
@@ -694,7 +696,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
// Add a static entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
@@ -756,7 +758,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
// Add a static entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
@@ -826,10 +828,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -907,150 +909,6 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
}
}
-func TestNeighborCacheNotifiesWaker(t *testing.T) {
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
-
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- const wakerID = 1
- s.AddWaker(&w, wakerID)
-
- entry, ok := store.entry(0)
- if !ok {
- t.Fatalf("store.entry(0) not found")
- }
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
- }
- if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
- }
- clock.Advance(typicalLatency)
-
- select {
- case <-doneCh:
- default:
- t.Fatal("expected notification from done channel")
- }
-
- id, ok := s.Fetch(false /* block */)
- if !ok {
- t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr)
- }
- if id != wakerID {
- t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID)
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
- }
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
- }
-}
-
-func TestNeighborCacheRemoveWaker(t *testing.T) {
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
-
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- const wakerID = 1
- s.AddWaker(&w, wakerID)
-
- entry, ok := store.entry(0)
- if !ok {
- t.Fatalf("store.entry(0) not found")
- }
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
- }
- if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
- }
-
- // Remove the waker before the neighbor cache has the opportunity to send a
- // notification.
- neigh.removeWaker(entry.Addr, &w)
- clock.Advance(typicalLatency)
-
- select {
- case <-doneCh:
- default:
- t.Fatal("expected notification from done channel")
- }
-
- if id, ok := s.Fetch(false /* block */); ok {
- t.Errorf("unexpected notification from waker with id %d", id)
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
- }
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
- }
-}
-
func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
config := DefaultNUDConfigurations()
// Stay in Reachable so the cache can overflow
@@ -1062,12 +920,12 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err)
+ t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1075,7 +933,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
State: Static,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
wantEvents := []testEntryEventInfo{
@@ -1129,10 +987,10 @@ func TestNeighborCacheClear(t *testing.T) {
// Add a dynamic entry.
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -1187,7 +1045,7 @@ func TestNeighborCacheClear(t *testing.T) {
}
}
- // Clear shoud remove both dynamic and static entries.
+ // Clear should remove both dynamic and static entries.
neigh.clear()
// Remove events dispatched from clear() have no deterministic order so they
@@ -1234,10 +1092,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -1318,7 +1176,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
frequentlyUsedEntry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
// The following logic is very similar to overflowCache, but
@@ -1330,15 +1188,22 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
- case <-doneCh:
+ case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
wantEvents := []testEntryEventInfo{
{
@@ -1373,7 +1238,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
// Periodically refresh the frequently used entry
if i%(neighborCacheSize/2) == 0 {
if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err)
}
}
@@ -1381,15 +1246,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
- case <-doneCh:
+ case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
// An entry should have been removed, as per the LRU eviction strategy
@@ -1435,7 +1308,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
}
// Expect to find only the frequently used entry and the most recent entries.
- // The order of entries reported by entries() is undeterministic, so entries
+ // The order of entries reported by entries() is nondeterministic, so entries
// have to be sorted before comparison.
wantUnsortedEntries := []NeighborEntry{
{
@@ -1494,12 +1367,12 @@ func TestNeighborCacheConcurrent(t *testing.T) {
go func(entry NeighborEntry) {
defer wg.Done()
if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
}
}(entry)
}
- // Wait for all gorountines to send a request
+ // Wait for all goroutines to send a request
wg.Wait()
// Process all the requests for a single entry concurrently
@@ -1509,7 +1382,7 @@ func TestNeighborCacheConcurrent(t *testing.T) {
// All goroutines add in the same order and add more values than can fit in
// the cache. Our eviction strategy requires that the last entries are
// present, up to the size of the neighbor cache, and the rest are missing.
- // The order of entries reported by entries() is undeterministic, so entries
+ // The order of entries reported by entries() is nondeterministic, so entries
// have to be sorted before comparison.
var wantUnsortedEntries []NeighborEntry
for i := store.size() - neighborCacheSize; i < store.size(); i++ {
@@ -1547,27 +1420,32 @@ func TestNeighborCacheReplace(t *testing.T) {
// Add an entry
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
- case <-doneCh:
+ case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
// Verify the entry exists
{
- e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
- }
- if doneCh != nil {
- t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err)
}
if t.Failed() {
t.FailNow()
@@ -1578,7 +1456,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
@@ -1587,7 +1465,7 @@ func TestNeighborCacheReplace(t *testing.T) {
{
entry, ok := store.entry(1)
if !ok {
- t.Fatalf("store.entry(1) not found")
+ t.Fatal("store.entry(1) not found")
}
updatedLinkAddr = entry.LinkAddr
}
@@ -1604,7 +1482,7 @@ func TestNeighborCacheReplace(t *testing.T) {
{
e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1612,7 +1490,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Delay,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
@@ -1622,7 +1500,7 @@ func TestNeighborCacheReplace(t *testing.T) {
e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
clock.Advance(typicalLatency)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1630,7 +1508,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
}
@@ -1654,18 +1532,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
},
}
- // First, sanity check that resolution is working
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+
+ // First, sanity check that resolution is working
+ {
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ clock.Advance(typicalLatency)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
}
- clock.Advance(typicalLatency)
+
got, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1673,20 +1568,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
- // Verify that address resolution for an unknown address returns ErrNoLinkAddress
+ // Verify address resolution fails for an unknown address.
before := atomic.LoadUint32(&requestCount)
entry.Addr += "2"
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
- }
- waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
- clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
+ {
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if ok {
+ t.Error("expected unsuccessful address resolution")
+ }
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
+ clock.Advance(waitFor)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
}
maxAttempts := neigh.config().MaxUnicastProbes
@@ -1714,15 +1624,129 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if ok {
+ t.Error("expected unsuccessful address resolution")
+ }
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
+
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
+}
+
+// TestNeighborCacheRetryResolution simulates retrying communication after
+// failing to perform address resolution.
+func TestNeighborCacheRetryResolution(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ // Simulate a faulty link.
+ dropReplies: true,
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatal("store.entry(0) not found")
+ }
+
+ // Perform address resolution with a faulty link, which will fail.
+ {
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if ok {
+ t.Error("expected unsuccessful address resolution")
+ }
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
+ clock.Advance(waitFor)
+
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
+ }
+
+ // Verify the entry is in Failed state.
+ wantEntries := []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Failed,
+ },
+ }
+ if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" {
+ t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Retry address resolution with a working link.
+ linkRes.dropReplies = false
+ {
+ incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ if incompleteEntry.State != Incomplete {
+ t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete)
+ }
+ clock.Advance(typicalLatency)
+
+ select {
+ case <-ch:
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ if err != nil {
+ t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err)
+ }
+ if reachableEntry.Addr != entry.Addr {
+ t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr)
+ }
+ if reachableEntry.LinkAddr != entry.LinkAddr {
+ t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr)
+ }
+ if reachableEntry.State != Reachable {
+ t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String())
+ }
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
}
}
@@ -1742,7 +1766,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) {
got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", testEntryBroadcastAddr, err)
}
want := NeighborEntry{
Addr: testEntryBroadcastAddr,
@@ -1750,7 +1774,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) {
State: Static,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff)
}
}
@@ -1775,12 +1799,23 @@ func BenchmarkCacheClear(b *testing.B) {
if !ok {
b.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ b.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
- if doneCh != nil {
- <-doneCh
+
+ select {
+ case <-ch:
+ default:
+ b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
}
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 32399b4f5..75afb3001 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -19,7 +19,6 @@ import (
"sync"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -67,8 +66,7 @@ const (
// Static describes entries that have been explicitly added by the user. They
// do not expire and are not deleted until explicitly removed.
Static
- // Failed means traffic should not be sent to this neighbor since attempts of
- // reachability have returned inconclusive.
+ // Failed means recent attempts of reachability have returned inconclusive.
Failed
)
@@ -93,16 +91,13 @@ type neighborEntry struct {
neigh NeighborEntry
- // wakers is a set of waiters for address resolution result. Anytime state
- // transitions out of incomplete these waiters are notified. It is nil iff
- // address resolution is ongoing and no clients are waiting for the result.
- wakers map[*sleep.Waker]struct{}
-
- // done is used to allow callers to wait on address resolution. It is nil
- // iff nudState is not Reachable and address resolution is not yet in
- // progress.
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
done chan struct{}
+ // onResolve is called with the result of address resolution.
+ onResolve []func(tcpip.LinkAddress, bool)
+
isRouter bool
job *tcpip.Job
}
@@ -143,25 +138,15 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd
}
}
-// addWaker adds w to the list of wakers waiting for address resolution.
-// Assumes the entry has already been appropriately locked.
-func (e *neighborEntry) addWakerLocked(w *sleep.Waker) {
- if w == nil {
- return
- }
- if e.wakers == nil {
- e.wakers = make(map[*sleep.Waker]struct{})
- }
- e.wakers[w] = struct{}{}
-}
-
-// notifyWakersLocked notifies those waiting for address resolution, whether it
-// succeeded or failed. Assumes the entry has already been appropriately locked.
-func (e *neighborEntry) notifyWakersLocked() {
- for w := range e.wakers {
- w.Assert()
+// notifyCompletionLocked notifies those waiting for address resolution, with
+// the link address if resolution completed successfully.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
+ for _, callback := range e.onResolve {
+ callback(e.neigh.LinkAddr, succeeded)
}
- e.wakers = nil
+ e.onResolve = nil
if ch := e.done; ch != nil {
close(ch)
e.done = nil
@@ -170,6 +155,8 @@ func (e *neighborEntry) notifyWakersLocked() {
// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
// been added.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchAddEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborAdded(e.nic.id, e.neigh)
@@ -178,6 +165,8 @@ func (e *neighborEntry) dispatchAddEventLocked() {
// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
// has changed state or link-layer address.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchChangeEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborChanged(e.nic.id, e.neigh)
@@ -186,23 +175,41 @@ func (e *neighborEntry) dispatchChangeEventLocked() {
// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry
// has been removed.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchRemoveEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborRemoved(e.nic.id, e.neigh)
}
}
+// cancelJobLocked cancels the currently scheduled action, if there is one.
+// Entries in Unknown, Stale, or Static state do not have a scheduled action.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) cancelJobLocked() {
+ if job := e.job; job != nil {
+ job.Cancel()
+ }
+}
+
+// removeLocked prepares the entry for removal.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) removeLocked() {
+ e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
+ e.dispatchRemoveEventLocked()
+ e.cancelJobLocked()
+ e.notifyCompletionLocked(false /* succeeded */)
+}
+
// setStateLocked transitions the entry to the specified state immediately.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
//
-// e.mu MUST be locked.
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) setStateLocked(next NeighborState) {
- // Cancel the previously scheduled action, if there is one. Entries in
- // Unknown, Stale, or Static state do not have scheduled actions.
- if timer := e.job; timer != nil {
- timer.Cancel()
- }
+ e.cancelJobLocked()
prev := e.neigh.State
e.neigh.State = next
@@ -257,11 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
e.job.Schedule(immediateDuration)
case Failed:
- e.notifyWakersLocked()
- e.job = e.nic.stack.newJob(&doubleLock{first: &e.nic.neigh.mu, second: &e.mu}, func() {
- e.nic.neigh.removeEntryLocked(e)
- })
- e.job.Schedule(config.UnreachableTime)
+ e.notifyCompletionLocked(false /* succeeded */)
case Unknown, Stale, Static:
// Do nothing
@@ -275,8 +278,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
// being queued for outgoing transmission.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
switch e.neigh.State {
+ case Failed:
+ e.nic.stats.Neighbor.FailedEntryLookups.Increment()
+
+ fallthrough
case Unknown:
e.neigh.State = Incomplete
e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
@@ -309,7 +318,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// implementation may find it convenient in some cases to return errors
// to the sender by taking the offending packet, generating an ICMP
// error message, and then delivering it (locally) through the generic
- // error-handling routines.' - RFC 4861 section 2.1
+ // error-handling routines." - RFC 4861 section 2.1
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
@@ -349,8 +358,6 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
case Incomplete, Reachable, Delay, Probe, Static:
// Do nothing
- case Failed:
- e.nic.stats.Neighbor.FailedEntryLookups.Increment()
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
@@ -360,18 +367,30 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// Neighbor Solicitation for ARP or NDP, respectively).
//
// Follows the logic defined in RFC 4861 section 7.2.3.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// Probes MUST be silently discarded if the target address is tentative, does
// not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These
// checks MUST be done by the NetworkEndpoint.
switch e.neigh.State {
- case Unknown, Incomplete, Failed:
+ case Unknown, Failed:
e.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
- e.notifyWakersLocked()
e.dispatchAddEventLocked()
+ case Incomplete:
+ // "If an entry already exists, and the cached link-layer address
+ // differs from the one in the received Source Link-Layer option, the
+ // cached address should be replaced by the received address, and the
+ // entry's reachability state MUST be set to STALE."
+ // - RFC 4861 section 7.2.3
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.setStateLocked(Stale)
+ e.notifyCompletionLocked(true /* succeeded */)
+ e.dispatchChangeEventLocked()
+
case Reachable, Delay, Probe:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
@@ -404,6 +423,8 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// not be possible. SEND uses RSA key pairs to produce Cryptographically
// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the
// claimed source of an NDP message is the owner of the claimed address.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
switch e.neigh.State {
case Incomplete:
@@ -422,7 +443,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
}
e.dispatchChangeEventLocked()
e.isRouter = flags.IsRouter
- e.notifyWakersLocked()
+ e.notifyCompletionLocked(true /* succeeded */)
// "Note that the Override flag is ignored if the entry is in the
// INCOMPLETE state." - RFC 4861 section 7.2.5
@@ -457,7 +478,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
wasReachable := e.neigh.State == Reachable
// Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
- e.notifyWakersLocked()
+ e.notifyCompletionLocked(true /* succeeded */)
if !wasReachable {
e.dispatchChangeEventLocked()
}
@@ -495,6 +516,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol
// (e.g. TCP acknowledgements) reachability confirmation.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
switch e.neigh.State {
case Reachable, Stale, Delay, Probe:
@@ -512,23 +535,3 @@ func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
}
-
-// doubleLock combines two locks into one while maintaining lock ordering.
-//
-// TODO(gvisor.dev/issue/4796): Remove this once subsequent traffic to a Failed
-// neighbor is allowed.
-type doubleLock struct {
- first, second sync.Locker
-}
-
-// Lock locks both locks in order: first then second.
-func (l *doubleLock) Lock() {
- l.first.Lock()
- l.second.Lock()
-}
-
-// Unlock unlocks both locks in reverse order: second then first.
-func (l *doubleLock) Unlock() {
- l.second.Unlock()
- l.first.Unlock()
-}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index c497d3932..ec34ffa5a 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -25,7 +25,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -73,36 +72,36 @@ func eventDiffOptsWithSort() []cmp.Option {
// The following unit tests exercise every state transition and verify its
// behavior with RFC 4681.
//
-// | From | To | Cause | Action | Event |
-// | ========== | ========== | ========================================== | =============== | ======= |
-// | Unknown | Unknown | Confirmation w/ unknown address | | Added |
-// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added |
-// | Unknown | Stale | Probe w/ unknown address | | Added |
-// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed |
-// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed |
-// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed |
-// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed |
-// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | |
-// | Reachable | Stale | Reachable timer expired | | Changed |
-// | Reachable | Stale | Probe or confirmation w/ different address | | Changed |
-// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
-// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
-// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
-// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
-// | Stale | Delay | Packet queued | | Changed |
-// | Delay | Reachable | Upper-layer confirmation | | Changed |
-// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
-// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
-// | Delay | Stale | Probe or confirmation w/ different address | | Changed |
-// | Delay | Probe | Delay timer expired | Send probe | Changed |
-// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
-// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed |
-// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
-// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
-// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
-// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
-// | Failed | Failed | Packet queued | | |
-// | Failed | | Unreachability timer expired | Delete entry | |
+// | From | To | Cause | Update | Action | Event |
+// | ========== | ========== | ========================================== | ======== | ===========| ======= |
+// | Unknown | Unknown | Confirmation w/ unknown address | | | Added |
+// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added |
+// | Unknown | Stale | Probe | | | Added |
+// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed |
+// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed |
+// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed |
+// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed |
+// | Incomplete | Failed | Max probes sent without reply | | Notify | Removed |
+// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | |
+// | Reachable | Stale | Reachable timer expired | | | Changed |
+// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Stale | Stale | Override confirmation | LinkAddr | | Changed |
+// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed |
+// | Stale | Delay | Packet sent | | | Changed |
+// | Delay | Reachable | Upper-layer confirmation | | | Changed |
+// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Delay | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Delay | Probe | Delay timer expired | | Send probe | Changed |
+// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed |
+// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Probe | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Probe | Probe | Retransmit timer expired | | | Changed |
+// | Probe | Failed | Max probes sent without reply | | Notify | Removed |
+// | Failed | Incomplete | Packet queued | | Send probe | Added |
type testEntryEventType uint8
@@ -258,8 +257,8 @@ func TestEntryInitiallyUnknown(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- if got, want := e.neigh.State, Unknown; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Unknown {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown)
}
e.mu.Unlock()
@@ -291,8 +290,8 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
Override: false,
IsRouter: false,
})
- if got, want := e.neigh.State, Unknown; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Unknown {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown)
}
e.mu.Unlock()
@@ -320,8 +319,8 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -367,8 +366,8 @@ func TestEntryUnknownToStale(t *testing.T) {
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -406,8 +405,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
updatedAtNanos := e.neigh.UpdatedAtNanos
e.mu.Unlock()
@@ -560,21 +559,15 @@ func TestEntryIncompleteToReachable(t *testing.T) {
nudDisp.mu.Unlock()
}
-// TestEntryAddsAndClearsWakers verifies that wakers are added when
-// addWakerLocked is called and cleared when address resolution finishes. In
-// this case, address resolution will finish when transitioning from Incomplete
-// to Reachable.
-func TestEntryAddsAndClearsWakers(t *testing.T) {
+func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- s.AddWaker(&w, 123)
- defer s.Done()
-
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ }
e.mu.Unlock()
runImmediatelyScheduledJobs(clock)
@@ -593,26 +586,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
}
e.mu.Lock()
- if got := e.wakers; got != nil {
- t.Errorf("got e.wakers = %v, want = nil", got)
- }
- e.addWakerLocked(&w)
- if got, want := w.IsAsserted(), false; got != want {
- t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
- }
- if e.wakers == nil {
- t.Error("expected e.wakers to be non-nil")
- }
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
- IsRouter: false,
+ IsRouter: true,
})
- if e.wakers != nil {
- t.Errorf("got e.wakers = %v, want = nil", e.wakers)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
- if got, want := w.IsAsserted(), true; got != want {
- t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
+ if !e.isRouter {
+ t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
}
e.mu.Unlock()
@@ -643,7 +626,7 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
+func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
@@ -663,22 +646,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
},
}
linkRes.mu.Lock()
- if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" {
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- linkRes.mu.Unlock()
e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
+ Solicited: false,
Override: false,
- IsRouter: true,
+ IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
- }
- if !e.isRouter {
- t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -698,7 +679,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
Entry: NeighborEntry{
Addr: entryTestAddr1,
LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ State: Stale,
},
},
}
@@ -709,7 +690,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryIncompleteToStale(t *testing.T) {
+func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
@@ -736,11 +717,7 @@ func TestEntryIncompleteToStale(t *testing.T) {
}
e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
+ e.handleProbeLocked(entryTestLinkAddr1)
if e.neigh.State != Stale {
t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
@@ -780,8 +757,8 @@ func TestEntryIncompleteToFailed(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -841,8 +818,8 @@ func TestEntryIncompleteToFailed(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Failed; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
}
e.mu.Unlock()
}
@@ -885,8 +862,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
Override: false,
IsRouter: true,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
if got, want := e.isRouter, true; got != want {
t.Errorf("got e.isRouter = %t, want = %t", got, want)
@@ -932,8 +909,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
e.mu.Unlock()
}
@@ -1083,8 +1060,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
}
@@ -2381,8 +2358,8 @@ func TestEntryDelayToProbe(t *testing.T) {
IsRouter: false,
})
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
}
e.mu.Unlock()
@@ -2447,8 +2424,8 @@ func TestEntryDelayToProbe(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.mu.Unlock()
}
@@ -2505,12 +2482,12 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -2620,16 +2597,16 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -2740,16 +2717,16 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
@@ -2836,16 +2813,16 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: true,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
@@ -2964,16 +2941,16 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: true,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
@@ -3101,16 +3078,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
IsRouter: false,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
e.mu.Unlock()
@@ -3435,212 +3412,61 @@ func TestEntryProbeToFailed(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryFailedToFailed(t *testing.T) {
+func TestEntryFailedToIncomplete(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
- c.MaxUnicastProbes = 3
e, nudDisp, linkRes, clock := entryTestSetup(c)
- // Verify the cache contains the entry.
- if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
- t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
- }
-
// TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in
// their expected state.
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.mu.Unlock()
-
- runImmediatelyScheduledJobs(clock)
- {
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.probes = nil
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
- waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
+ waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes)
clock.Advance(waitFor)
- {
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
- }
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- },
- },
+ wantProbes := []entryTestProbeInfo{
+ // The Incomplete-to-Incomplete state transition is tested here by
+ // verifying that 3 reachability probes were sent.
{
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
- },
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
},
{
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
- },
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
},
{
- EventType: entryTestRemoved,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
- },
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
},
}
- nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-
- failedLookups := e.nic.stats.Neighbor.FailedEntryLookups
- if got := failedLookups.Value(); got != 0 {
- t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 0", got)
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
e.mu.Lock()
- // Verify queuing a packet to the entry immediately fails.
- e.handlePacketQueuedLocked(entryTestAddr2)
- state := e.neigh.State
- e.mu.Unlock()
- if state != Failed {
- t.Errorf("got e.neigh.State = %q, want = %q", state, Failed)
- }
-
- if got := failedLookups.Value(); got != 1 {
- t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 1", got)
- }
-}
-
-func TestEntryFailedGetsDeleted(t *testing.T) {
- c := DefaultNUDConfigurations()
- c.MaxMulticastProbes = 3
- c.MaxUnicastProbes = 3
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- // Verify the cache contains the entry.
- if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
- t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
+ if e.neigh.State != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
}
-
- e.mu.Lock()
- e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
- runImmediatelyScheduledJobs(clock)
- {
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.probes = nil
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
- }
-
e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
e.handlePacketQueuedLocked(entryTestAddr2)
- e.mu.Unlock()
-
- waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
- clock.Advance(waitFor)
- {
- wantProbes := []entryTestProbeInfo{
- // The next three probe are sent in Probe.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
+ e.mu.Unlock()
wantEvents := []testEntryEventInfo{
{
@@ -3653,39 +3479,21 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
},
},
{
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
- },
- },
- {
- EventType: entryTestChanged,
+ EventType: entryTestRemoved,
NICID: entryTestNICID,
Entry: NeighborEntry{
Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
},
},
{
- EventType: entryTestRemoved,
+ EventType: entryTestAdded,
NICID: entryTestNICID,
Entry: NeighborEntry{
Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
},
},
}
@@ -3694,9 +3502,4 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- // Verify the cache no longer contains the entry.
- if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok {
- t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1)
- }
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 5887aa1ed..4a34805b5 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -20,7 +20,6 @@ import (
"reflect"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -54,9 +53,9 @@ type NIC struct {
sync.RWMutex
spoofing bool
promiscuous bool
- // packetEPs is protected by mu, but the contained PacketEndpoint
- // values are not.
- packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
+ // packetEPs is protected by mu, but the contained packetEndpointList are
+ // not.
+ packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList
}
}
@@ -82,6 +81,39 @@ type DirectionStats struct {
Bytes *tcpip.StatCounter
}
+type packetEndpointList struct {
+ mu sync.RWMutex
+
+ // eps is protected by mu, but the contained PacketEndpoint values are not.
+ eps []PacketEndpoint
+}
+
+func (p *packetEndpointList) add(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.eps = append(p.eps, ep)
+}
+
+func (p *packetEndpointList) remove(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ for i, epOther := range p.eps {
+ if epOther == ep {
+ p.eps = append(p.eps[:i], p.eps[i+1:]...)
+ break
+ }
+ }
+}
+
+// forEach calls fn with each endpoints in p while holding the read lock on p.
+func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ for _, ep := range p.eps {
+ fn(ep)
+ }
+}
+
// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
@@ -102,7 +134,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
- nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
+ nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
// Check for Neighbor Unreachability Detection support.
var nud NUDHandler
@@ -125,11 +157,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// Register supported packet and network endpoint protocols.
for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = []PacketEndpoint{}
+ nic.mu.packetEPs[netProto] = new(packetEndpointList)
}
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = nil
+ nic.mu.packetEPs[netNum] = new(packetEndpointList)
nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
@@ -172,7 +204,7 @@ func (n *NIC) disable() {
//
// n MUST be locked.
func (n *NIC) disableLocked() {
- if !n.setEnabled(false) {
+ if !n.Enabled() {
return
}
@@ -184,6 +216,10 @@ func (n *NIC) disableLocked() {
for _, ep := range n.networkEndpoints {
ep.Disable()
}
+
+ if !n.setEnabled(false) {
+ panic("should have only done work to disable the NIC if it was enabled")
+ }
}
// enable enables n.
@@ -258,15 +294,17 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// the same unresolved IP address, and transmit the saved
// packet when the address has been resolved.
//
- // RFC 4861 section 5.2 (for IPv6):
- // Once the IP address of the next-hop node is known, the sender
- // examines the Neighbor Cache for link-layer information about that
- // neighbor. If no entry exists, the sender creates one, sets its state
- // to INCOMPLETE, initiates Address Resolution, and then queues the data
- // packet pending completion of address resolution.
+ // RFC 4861 section 7.2.2 (for IPv6):
+ // While waiting for address resolution to complete, the sender MUST, for
+ // each neighbor, retain a small queue of packets waiting for address
+ // resolution to complete. The queue MUST hold at least one packet, and MAY
+ // contain more. However, the number of queued packets per neighbor SHOULD
+ // be limited to some small value. When a queue overflows, the new arrival
+ // SHOULD replace the oldest entry. Once address resolution completes, the
+ // node transmits any queued packets.
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
- r := r.Clone()
+ r.Acquire()
n.stack.linkResQueue.enqueue(ch, r, protocol, pkt)
return nil
}
@@ -279,7 +317,9 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// WritePacketToRemote implements NetworkInterface.
func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
r := Route{
- NetProto: protocol,
+ routeInfo: routeInfo{
+ NetProto: protocol,
+ },
}
r.ResolveWith(remoteLinkAddr)
return n.writePacket(&r, gso, protocol, pkt)
@@ -508,14 +548,6 @@ func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
return n.neigh.entries(), nil
}
-func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) {
- if n.neigh == nil {
- return
- }
-
- n.neigh.removeWaker(addr, w)
-}
-
func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error {
if n.neigh == nil {
return tcpip.ErrNotSupported
@@ -634,15 +666,23 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
// Are any packet type sockets listening for this network protocol?
- packetEPs := n.mu.packetEPs[protocol]
- // Add any other packet type sockets that may be listening for all protocols.
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
+ protoEPs := n.mu.packetEPs[protocol]
+ // Other packet type sockets that are listening for all protocols.
+ anyEPs := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ // Deliver to interested packet endpoints without holding NIC lock.
+ deliverPacketEPs := func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketHost
ep.HandlePacket(n.id, local, protocol, p)
}
+ if protoEPs != nil {
+ protoEPs.forEach(deliverPacketEPs)
+ }
+ if anyEPs != nil {
+ anyEPs.forEach(deliverPacketEPs)
+ }
// Parse headers.
netProto := n.stack.NetworkProtocolInstance(protocol)
@@ -683,16 +723,17 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
// We do not deliver to protocol specific packet endpoints as on Linux
// only ETH_P_ALL endpoints get outbound packets.
// Add any other packet sockets that maybe listening for all protocols.
- packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ eps := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ eps.forEach(func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
n.LinkEndpoint.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
- }
+ })
}
// DeliverTransportPacket delivers the packets to the appropriate transport
@@ -845,7 +886,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
if !ok {
return tcpip.ErrNotSupported
}
- n.mu.packetEPs[netProto] = append(eps, ep)
+ eps.add(ep)
return nil
}
@@ -858,13 +899,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
if !ok {
return
}
-
- for i, epOther := range eps {
- if epOther == ep {
- n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
- return
- }
- }
+ eps.remove(ep)
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index ab629b3a4..12d67409a 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -109,14 +109,6 @@ const (
//
// Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10.
defaultMaxReachbilityConfirmations = 3
-
- // defaultUnreachableTime is the default duration for how long an entry will
- // remain in the FAILED state before being removed from the neighbor cache.
- //
- // Note, there is no equivalent protocol constant defined in RFC 4861. It
- // leaves the specifics of any garbage collection mechanism up to the
- // implementation.
- defaultUnreachableTime = 5 * time.Second
)
// NUDDispatcher is the interface integrators of netstack must implement to
@@ -278,10 +270,6 @@ type NUDConfigurations struct {
// TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD
// configuration option is necessary.
MaxReachabilityConfirmations uint32
-
- // UnreachableTime describes how long an entry will remain in the FAILED
- // state before being removed from the neighbor cache.
- UnreachableTime time.Duration
}
// DefaultNUDConfigurations returns a NUDConfigurations populated with default
@@ -299,7 +287,6 @@ func DefaultNUDConfigurations() NUDConfigurations {
MaxUnicastProbes: defaultMaxUnicastProbes,
MaxAnycastDelayTime: defaultMaxAnycastDelayTime,
MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations,
- UnreachableTime: defaultUnreachableTime,
}
}
@@ -329,9 +316,6 @@ func (c *NUDConfigurations) resetInvalidFields() {
if c.MaxUnicastProbes == 0 {
c.MaxUnicastProbes = defaultMaxUnicastProbes
}
- if c.UnreachableTime == 0 {
- c.UnreachableTime = defaultUnreachableTime
- }
}
// calcMaxRandomFactor calculates the maximum value of the random factor used
@@ -416,7 +400,7 @@ func (s *NUDState) ReachableTime() time.Duration {
s.config.BaseReachableTime != s.prevBaseReachableTime ||
s.config.MinRandomFactor != s.prevMinRandomFactor ||
s.config.MaxRandomFactor != s.prevMaxRandomFactor {
- return s.recomputeReachableTimeLocked()
+ s.recomputeReachableTimeLocked()
}
return s.reachableTime
}
@@ -442,7 +426,7 @@ func (s *NUDState) ReachableTime() time.Duration {
// random value gets re-computed at least once every few hours.
//
// s.mu MUST be locked for writing.
-func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
+func (s *NUDState) recomputeReachableTimeLocked() {
s.prevBaseReachableTime = s.config.BaseReachableTime
s.prevMinRandomFactor = s.config.MinRandomFactor
s.prevMaxRandomFactor = s.config.MaxRandomFactor
@@ -462,5 +446,4 @@ func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
}
s.expiration = time.Now().Add(2 * time.Hour)
- return s.reachableTime
}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index 8cffb9fc6..7bca1373e 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -37,7 +37,6 @@ const (
defaultMaxUnicastProbes = 3
defaultMaxAnycastDelayTime = time.Second
defaultMaxReachbilityConfirmations = 3
- defaultUnreachableTime = 5 * time.Second
defaultFakeRandomNum = 0.5
)
@@ -565,58 +564,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
}
}
-func TestNUDConfigurationsUnreachableTime(t *testing.T) {
- tests := []struct {
- name string
- unreachableTime time.Duration
- want time.Duration
- }{
- // Invalid cases
- {
- name: "EqualToZero",
- unreachableTime: 0,
- want: defaultUnreachableTime,
- },
- // Valid cases
- {
- name: "MoreThanZero",
- unreachableTime: time.Millisecond,
- want: time.Millisecond,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const nicID = 1
-
- c := stack.DefaultNUDConfigurations()
- c.UnreachableTime = test.unreachableTime
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The
- // networking stack will only allocate neighbor caches if a protocol
- // providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NUDConfigs: c,
- UseNeighborCache: true,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- sc, err := s.NUDConfigurations(nicID)
- if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
- }
- if got := sc.UnreachableTime; got != test.want {
- t.Errorf("got UnreachableTime = %q, want = %q", got, test.want)
- }
- })
- }
-}
-
// TestNUDStateReachableTime verifies the correctness of the ReachableTime
// computation.
func TestNUDStateReachableTime(t *testing.T) {
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 5d364a2b0..4a3adcf33 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -103,7 +103,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
for _, p := range packets {
if cancelled {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
- } else if _, err := p.route.Resolve(nil); err != nil {
+ } else if p.route.IsResolutionRequired() {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index b334e27c4..7e83b7fbb 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -17,7 +17,6 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -799,19 +798,26 @@ type LinkAddressCache interface {
// AddLinkAddress adds a link address to the cache.
AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
- // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
- // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
- // registered with the network protocol, the cache attempts to resolve the address
- // and returns ErrWouldBlock. Waker is notified when address resolution is
- // complete (success or not).
+ // GetLinkAddress finds the link address corresponding to the remote address
+ // (e.g. IP -> MAC).
//
- // If address resolution is required, ErrNoLinkAddress and a notification channel is
- // returned for the top level caller to block. Channel is closed once address resolution
- // is complete (success or not).
- GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
-
- // RemoveWaker removes a waker that has been added in GetLinkAddress().
- RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+ // Returns a link address for the remote address, if readily available.
+ //
+ // Returns ErrWouldBlock if the link address is not readily available, along
+ // with a notification channel for the caller to block on. Triggers address
+ // resolution asynchronously.
+ //
+ // If onResolve is provided, it will be called either immediately, if
+ // resolution is not required, or when address resolution is complete, with
+ // the resolved link address and whether resolution succeeded. After any
+ // callbacks have been called, the returned notification channel is closed.
+ //
+ // If specified, the local address must be an address local to the interface
+ // the neighbor cache belongs to. The local address is the source address of
+ // a packet prompting NUD/link address resolution.
+ //
+ // TODO(gvisor.dev/issue/5151): Don't return the link address.
+ GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
}
// RawFactory produces endpoints for writing various types of raw packets.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index de5fe6ffe..b0251d0b4 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -17,7 +17,6 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -31,24 +30,7 @@ import (
//
// TODO(gvisor.dev/issue/4902): Unexpose immutable fields.
type Route struct {
- // RemoteAddress is the final destination of the route.
- RemoteAddress tcpip.Address
-
- // LocalAddress is the local address where the route starts.
- LocalAddress tcpip.Address
-
- // LocalLinkAddress is the link-layer (MAC) address of the
- // where the route starts.
- LocalLinkAddress tcpip.LinkAddress
-
- // NextHop is the next node in the path to the destination.
- NextHop tcpip.Address
-
- // NetProto is the network-layer protocol.
- NetProto tcpip.NetworkProtocolNumber
-
- // Loop controls where WritePacket should send packets.
- Loop PacketLooping
+ routeInfo
// localAddressNIC is the interface the address is associated with.
// TODO(gvisor.dev/issue/4548): Remove this field once we can query the
@@ -78,6 +60,45 @@ type Route struct {
linkRes LinkAddressResolver
}
+type routeInfo struct {
+ // RemoteAddress is the final destination of the route.
+ RemoteAddress tcpip.Address
+
+ // LocalAddress is the local address where the route starts.
+ LocalAddress tcpip.Address
+
+ // LocalLinkAddress is the link-layer (MAC) address of the
+ // where the route starts.
+ LocalLinkAddress tcpip.LinkAddress
+
+ // NextHop is the next node in the path to the destination.
+ NextHop tcpip.Address
+
+ // NetProto is the network-layer protocol.
+ NetProto tcpip.NetworkProtocolNumber
+
+ // Loop controls where WritePacket should send packets.
+ Loop PacketLooping
+}
+
+// RouteInfo contains all of Route's exported fields.
+type RouteInfo struct {
+ routeInfo
+
+ // RemoteLinkAddress is the link-layer (MAC) address of the next hop in the
+ // route.
+ RemoteLinkAddress tcpip.LinkAddress
+}
+
+// GetFields returns a RouteInfo with all of r's exported fields. This allows
+// callers to store the route's fields without retaining a reference to it.
+func (r *Route) GetFields() RouteInfo {
+ return RouteInfo{
+ routeInfo: r.routeInfo,
+ RemoteLinkAddress: r.RemoteLinkAddress(),
+ }
+}
+
// constructAndValidateRoute validates and initializes a route. It takes
// ownership of the provided local address.
//
@@ -152,13 +173,15 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route {
r := &Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- localAddressNIC: localAddressNIC,
- outgoingNIC: outgoingNIC,
- Loop: loop,
+ routeInfo: routeInfo{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ Loop: loop,
+ },
+ localAddressNIC: localAddressNIC,
+ outgoingNIC: outgoingNIC,
}
r.mu.Lock()
@@ -264,22 +287,21 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
r.mu.remoteLinkAddress = addr
}
-// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
-// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
-// notified when address resolution is complete (success or not).
+// Resolve attempts to resolve the link address if necessary.
//
-// If address resolution is required, ErrNoLinkAddress and a notification channel is
-// returned for the top level caller to block. Channel is closed once address resolution
-// is complete (success or not).
-//
-// The NIC r uses must not be locked.
-func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
+// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g.
+// waiting for ARP reply). If address resolution is required, a notification
+// channel is also returned for the caller to block on. The channel is closed
+// once address resolution is complete (successful or not). If a callback is
+// provided, it will be called when address resolution is complete, regardless
+// of success or failure.
+func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
r.mu.Lock()
- defer r.mu.Unlock()
if !r.isResolutionRequiredRLocked() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
// link address is already known.
+ r.mu.Unlock()
return nil, nil
}
@@ -288,6 +310,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
// Local link address is already known.
if r.RemoteAddress == r.LocalAddress {
r.mu.remoteLinkAddress = r.LocalLinkAddress
+ r.mu.Unlock()
return nil, nil
}
nextAddr = r.RemoteAddress
@@ -300,38 +323,36 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
linkAddressResolutionRequestLocalAddr = r.LocalAddress
}
+ // Increment the route's reference count because finishResolution retains a
+ // reference to the route and releases it when called.
+ r.acquireLocked()
+ r.mu.Unlock()
+
+ finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) {
+ if ok {
+ r.ResolveWith(linkAddress)
+ }
+ if afterResolve != nil {
+ afterResolve()
+ }
+ r.Release()
+ }
+
if neigh := r.outgoingNIC.neigh; neigh != nil {
- entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker)
+ _, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution)
if err != nil {
return ch, err
}
- r.mu.remoteLinkAddress = entry.LinkAddr
return nil, nil
}
- linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker)
+ _, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, finishResolution)
if err != nil {
return ch, err
}
- r.mu.remoteLinkAddress = linkAddr
return nil, nil
}
-// RemoveWaker removes a waker that has been added in Resolve().
-func (r *Route) RemoveWaker(waker *sleep.Waker) {
- nextAddr := r.NextHop
- if nextAddr == "" {
- nextAddr = r.RemoteAddress
- }
-
- if neigh := r.outgoingNIC.neigh; neigh != nil {
- neigh.removeWaker(nextAddr, waker)
- return
- }
-
- r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker)
-}
-
// local returns true if the route is a local route.
func (r *Route) local() bool {
return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback()
@@ -419,46 +440,31 @@ func (r *Route) MTU() uint32 {
return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU()
}
-// Release frees all resources associated with the route.
+// Release decrements the reference counter of the resources associated with the
+// route.
func (r *Route) Release() {
r.mu.Lock()
defer r.mu.Unlock()
- if r.mu.localAddressEndpoint != nil {
- r.mu.localAddressEndpoint.DecRef()
- r.mu.localAddressEndpoint = nil
+ if ep := r.mu.localAddressEndpoint; ep != nil {
+ ep.DecRef()
}
}
-// Clone clones the route.
-func (r *Route) Clone() *Route {
+// Acquire increments the reference counter of the resources associated with the
+// route.
+func (r *Route) Acquire() {
r.mu.RLock()
defer r.mu.RUnlock()
+ r.acquireLocked()
+}
- newRoute := &Route{
- RemoteAddress: r.RemoteAddress,
- LocalAddress: r.LocalAddress,
- LocalLinkAddress: r.LocalLinkAddress,
- NextHop: r.NextHop,
- NetProto: r.NetProto,
- Loop: r.Loop,
- localAddressNIC: r.localAddressNIC,
- outgoingNIC: r.outgoingNIC,
- linkCache: r.linkCache,
- linkRes: r.linkRes,
- }
-
- newRoute.mu.Lock()
- defer newRoute.mu.Unlock()
- newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint
- if newRoute.mu.localAddressEndpoint != nil {
- if !newRoute.mu.localAddressEndpoint.IncRef() {
- panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress))
+func (r *Route) acquireLocked() {
+ if ep := r.mu.localAddressEndpoint; ep != nil {
+ if !ep.IncRef() {
+ panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
}
}
- newRoute.mu.remoteLinkAddress = r.mu.remoteLinkAddress
-
- return newRoute
}
// Stack returns the instance of the Stack that owns this route.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index dc4f5b3e7..114643b03 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -29,7 +29,6 @@ import (
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -171,6 +170,9 @@ type TCPSenderState struct {
// Outstanding is the number of packets in flight.
Outstanding int
+ // SackedOut is the number of packets which have been selectively acked.
+ SackedOut int
+
// SndWnd is the send window size in bytes.
SndWnd seqnum.Size
@@ -1517,7 +1519,7 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t
}
// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
-func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
s.mu.RLock()
nic := s.nics[nicID]
if nic == nil {
@@ -1528,7 +1530,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, onResolve)
}
// Neighbors returns all IP to MAC address associations.
@@ -1544,29 +1546,6 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) {
return nic.neighbors()
}
-// RemoveWaker removes a waker that has been added when link resolution for
-// addr was requested.
-func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
- if s.useNeighborCache {
- s.mu.RLock()
- nic, ok := s.nics[nicID]
- s.mu.RUnlock()
-
- if ok {
- nic.removeWaker(addr, waker)
- }
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- if nic := s.nics[nicID]; nic == nil {
- fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
- s.linkAddrCache.removeWaker(fullAddr, waker)
- }
-}
-
// AddStaticNeighbor statically associates an IP address to a MAC address.
func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error {
s.mu.RLock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 457990945..856ebf6d4 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -1602,7 +1602,10 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, &stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ var wantRoute stack.Route
+ wantRoute.LocalAddress = header.IPv4Any
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1656,7 +1659,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ var wantRoute stack.Route
+ wantRoute.LocalAddress = nic1Addr.Address
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1666,7 +1672,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, &stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ wantRoute = stack.Route{}
+ wantRoute.LocalAddress = nic2Addr.Address
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1682,7 +1691,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ wantRoute = stack.Route{}
+ wantRoute.LocalAddress = nic1Addr.Address
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
}
@@ -2726,8 +2738,16 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- nicID = 1
- lifetimeSeconds = 9999
+ globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01")
+ ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02")
+ toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+
+ nicID = 1
+ lifetimeSeconds = 9999
)
prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1)
@@ -2744,139 +2764,191 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix
nicAddrs []tcpip.Address
slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix
- connectAddr tcpip.Address
+ remoteAddr tcpip.Address
expectedLocalAddr tcpip.Address
}{
- // Test Rule 1 of RFC 6724 section 5.
+ // Test Rule 1 of RFC 6724 section 5 (prefer same address).
{
name: "Same Global most preferred (last address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr1,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: globalAddr1,
expectedLocalAddr: globalAddr1,
},
{
name: "Same Global most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: globalAddr1,
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1},
+ remoteAddr: globalAddr1,
expectedLocalAddr: globalAddr1,
},
{
name: "Same Link Local most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalAddr1,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: linkLocalAddr1,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Same Link Local most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalAddr1,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: linkLocalAddr1,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Same Unique Local most preferred (last address)",
- nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
- connectAddr: uniqueLocalAddr1,
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1},
+ remoteAddr: uniqueLocalAddr1,
expectedLocalAddr: uniqueLocalAddr1,
},
{
name: "Same Unique Local most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: uniqueLocalAddr1,
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1},
+ remoteAddr: uniqueLocalAddr1,
expectedLocalAddr: uniqueLocalAddr1,
},
- // Test Rule 2 of RFC 6724 section 5.
+ // Test Rule 2 of RFC 6724 section 5 (prefer appropriate scope).
{
name: "Global most preferred (last address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr2,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: globalAddr2,
expectedLocalAddr: globalAddr1,
},
{
name: "Global most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: globalAddr2,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: globalAddr2,
expectedLocalAddr: globalAddr1,
},
{
name: "Link Local most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalAddr2,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: linkLocalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalAddr2,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: linkLocalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local most preferred for link local multicast (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalMulticastAddr,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: linkLocalMulticastAddr,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local most preferred for link local multicast (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalMulticastAddr,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: linkLocalMulticastAddr,
expectedLocalAddr: linkLocalAddr1,
},
+
+ // Test Rule 6 of 6724 section 5 (prefer matching label).
{
name: "Unique Local most preferred (last address)",
- nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
- connectAddr: uniqueLocalAddr2,
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1},
+ remoteAddr: uniqueLocalAddr2,
expectedLocalAddr: uniqueLocalAddr1,
},
{
name: "Unique Local most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: uniqueLocalAddr2,
+ nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1},
+ remoteAddr: uniqueLocalAddr2,
expectedLocalAddr: uniqueLocalAddr1,
},
+ {
+ name: "Toredo most preferred (first address)",
+ nicAddrs: []tcpip.Address{toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1},
+ remoteAddr: toredoAddr2,
+ expectedLocalAddr: toredoAddr1,
+ },
+ {
+ name: "Toredo most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1},
+ remoteAddr: toredoAddr2,
+ expectedLocalAddr: toredoAddr1,
+ },
+ {
+ name: "6To4 most preferred (first address)",
+ nicAddrs: []tcpip.Address{ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1},
+ remoteAddr: ipv6ToIPv4Addr2,
+ expectedLocalAddr: ipv6ToIPv4Addr1,
+ },
+ {
+ name: "6To4 most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, uniqueLocalAddr1, toredoAddr1, ipv6ToIPv4Addr1},
+ remoteAddr: ipv6ToIPv4Addr2,
+ expectedLocalAddr: ipv6ToIPv4Addr1,
+ },
+ {
+ name: "IPv4 mapped IPv6 most preferred (first address)",
+ nicAddrs: []tcpip.Address{ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1},
+ remoteAddr: ipv4MappedIPv6Addr2,
+ expectedLocalAddr: ipv4MappedIPv6Addr1,
+ },
+ {
+ name: "IPv4 mapped IPv6 most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1, ipv4MappedIPv6Addr1},
+ remoteAddr: ipv4MappedIPv6Addr2,
+ expectedLocalAddr: ipv4MappedIPv6Addr1,
+ },
- // Test Rule 7 of RFC 6724 section 5.
+ // Test Rule 7 of RFC 6724 section 5 (prefer temporary addresses).
{
name: "Temp Global most preferred (last address)",
slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: tempGlobalAddr1,
},
{
name: "Temp Global most preferred (first address)",
nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
slaacPrefixForTempAddrAfterNICAddrAdd: prefix1,
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: tempGlobalAddr1,
},
+ // Test Rule 8 of RFC 6724 section 5 (use longest matching prefix).
+ {
+ name: "Longest prefix matched most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr2, globalAddr1},
+ remoteAddr: globalAddr3,
+ expectedLocalAddr: globalAddr2,
+ },
+ {
+ name: "Longest prefix matched most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, globalAddr2},
+ remoteAddr: globalAddr3,
+ expectedLocalAddr: globalAddr2,
+ },
+
// Test returning the endpoint that is closest to the front when
// candidate addresses are "equal" from the perspective of RFC 6724
// section 5.
{
name: "Unique Local for Global",
nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2},
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: uniqueLocalAddr1,
},
{
name: "Link Local for Global",
nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local for Unique Local",
nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
- connectAddr: uniqueLocalAddr2,
+ remoteAddr: uniqueLocalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Temp Global for Global",
slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
slaacPrefixForTempAddrAfterNICAddrAdd: prefix2,
- connectAddr: globalAddr1,
+ remoteAddr: globalAddr1,
expectedLocalAddr: tempGlobalAddr2,
},
}
@@ -2898,12 +2970,6 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- Gateway: llAddr3,
- NIC: nicID,
- }})
- s.AddLinkAddress(nicID, llAddr3, linkAddr3)
if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) {
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds))
@@ -2923,7 +2989,23 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
t.FailNow()
}
- if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr {
+ netEP, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ }
+
+ addressableEndpoint, ok := netEP.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("network endpoint is not addressable")
+ }
+
+ addressEP := addressableEndpoint.AcquireOutgoingPrimaryAddress(test.remoteAddr, false /* allowExpired */)
+ if addressEP == nil {
+ t.Fatal("expected a non-nil address endpoint")
+ }
+ defer addressEP.DecRef()
+
+ if got := addressEP.AddressWithPrefix().Address; got != test.expectedLocalAddr {
t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr)
}
})
@@ -4204,8 +4286,8 @@ func TestWritePacketToRemote(t *testing.T) {
if got, want := pkt.Proto, test.protocol; got != want {
t.Fatalf("pkt.Proto = %d, want %d", got, want)
}
- if got, want := pkt.Route.RemoteLinkAddress(), linkAddr2; got != want {
- t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want)
+ if pkt.Route.RemoteLinkAddress != linkAddr2 {
+ t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2)
}
if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" {
t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 2cdb5ca79..737d8d912 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -141,11 +141,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: testSrcAddrV6,
- DstAddr: testDstAddrV6,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: testSrcAddrV6,
+ DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
@@ -308,9 +308,8 @@ func TestBindToDeviceDistribution(t *testing.T) {
defer ep.Close()
ep.SocketOptions().SetReusePort(endpoint.reuse)
- bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
- if err := ep.SetSockOpt(&bindToDeviceOption); err != nil {
- t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err)
+ if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil {
+ t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err)
}
var dstAddr tcpip.Address
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index d9769e47d..dd552b8b9 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -77,6 +77,7 @@ func (f *fakeTransportEndpoint) Abort() {
}
func (f *fakeTransportEndpoint) Close() {
+ // TODO(gvisor.dev/issue/5153): Consider retaining the route.
f.route.Release()
}
@@ -109,8 +110,8 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return int64(len(v)), nil, nil
}
-func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (*fakeTransportEndpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// SetSockOpt sets a socket option. Currently not supported.
@@ -146,16 +147,16 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return tcpip.ErrNoRoute
}
- defer r.Release()
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
+ r.Release()
return err
}
- f.route = r.Clone()
+ f.route = r
return nil
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 2bd472811..ef0f51f1a 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -49,8 +49,9 @@ const ipv4AddressSize = 4
// Error represents an error in the netstack error space. Using a special type
// ensures that errors outside of this space are not accidentally introduced.
//
-// Note: to support save / restore, it is important that all tcpip errors have
-// distinct error messages.
+// All errors must have unique msg strings.
+//
+// +stateify savable
type Error struct {
msg string
@@ -257,6 +258,44 @@ func (a Address) Unspecified() bool {
return true
}
+// MatchingPrefix returns the matching prefix length in bits.
+//
+// Panics if b and a have different lengths.
+func (a Address) MatchingPrefix(b Address) uint8 {
+ const bitsInAByte = 8
+
+ if len(a) != len(b) {
+ panic(fmt.Sprintf("addresses %s and %s do not have the same length", a, b))
+ }
+
+ var prefix uint8
+ for i := range a {
+ aByte := a[i]
+ bByte := b[i]
+
+ if aByte == bByte {
+ prefix += bitsInAByte
+ continue
+ }
+
+ // Count the remaining matching bits in the byte from MSbit to LSBbit.
+ mask := uint8(1) << (bitsInAByte - 1)
+ for {
+ if aByte&mask == bByte&mask {
+ prefix++
+ mask >>= 1
+ continue
+ }
+
+ break
+ }
+
+ break
+ }
+
+ return prefix
+}
+
// AddressMask is a bitmask for an address.
type AddressMask string
@@ -491,6 +530,17 @@ type ControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo IPPacketInfo
+
+ // HasOriginalDestinationAddress indicates whether OriginalDstAddress is
+ // set.
+ HasOriginalDstAddress bool
+
+ // OriginalDestinationAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress FullAddress
+
+ // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE).
+ SockErr *SockError
}
// PacketOwner is used to get UID and GID of the packet.
@@ -545,7 +595,7 @@ type Endpoint interface {
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
- Peek([][]byte) (int64, ControlMessages, *Error)
+ Peek([][]byte) (int64, *Error)
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
@@ -905,14 +955,6 @@ type SettableSocketOption interface {
isSettableSocketOption()
}
-// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
-// should bind only on a specific NIC.
-type BindToDeviceOption NICID
-
-func (*BindToDeviceOption) isGettableSocketOption() {}
-
-func (*BindToDeviceOption) isSettableSocketOption() {}
-
// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
//
// TODO(b/64800844): Add and populate stat fields.
@@ -1087,14 +1129,6 @@ type RemoveMembershipOption MembershipOption
func (*RemoveMembershipOption) isSettableSocketOption() {}
-// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether
-// TCP out-of-band data is delivered along with the normal in-band data.
-type OutOfBandInlineOption int
-
-func (*OutOfBandInlineOption) isGettableSocketOption() {}
-
-func (*OutOfBandInlineOption) isSettableSocketOption() {}
-
// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached
// classic BPF filter on a given endpoint.
type SocketDetachFilterOption int
@@ -1144,10 +1178,6 @@ type LingerOption struct {
Timeout time.Duration
}
-func (*LingerOption) isGettableSocketOption() {}
-
-func (*LingerOption) isSettableSocketOption() {}
-
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index c461da137..9bd563c46 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -270,3 +270,43 @@ func TestAddressUnspecified(t *testing.T) {
})
}
}
+
+func TestAddressMatchingPrefix(t *testing.T) {
+ tests := []struct {
+ addrA Address
+ addrB Address
+ prefix uint8
+ }{
+ {
+ addrA: "\x01\x01",
+ addrB: "\x01\x01",
+ prefix: 16,
+ },
+ {
+ addrA: "\x01\x01",
+ addrB: "\x01\x00",
+ prefix: 15,
+ },
+ {
+ addrA: "\x01\x01",
+ addrB: "\x81\x00",
+ prefix: 0,
+ },
+ {
+ addrA: "\x01\x01",
+ addrB: "\x01\x80",
+ prefix: 8,
+ },
+ {
+ addrA: "\x01\x01",
+ addrB: "\x02\x80",
+ prefix: 6,
+ },
+ }
+
+ for _, test := range tests {
+ if got := test.addrA.MatchingPrefix(test.addrB); got != test.prefix {
+ t.Errorf("got (%s).MatchingPrefix(%s) = %d, want = %d", test.addrA, test.addrB, got, test.prefix)
+ }
+ }
+}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 8be791a00..2e59f6a42 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -96,11 +96,11 @@ func TestPingMulticastBroadcast(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: ttl,
- SrcAddr: remoteIPv6Addr,
- DstAddr: dst,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -272,11 +272,11 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLen),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: ttl,
- SrcAddr: remoteIPv6Addr,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLen),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 94fcd72d9..d1e4a7cb7 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -75,8 +75,6 @@ type endpoint struct {
route *stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -332,21 +330,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.SocketDetachFilterOption:
- return nil
-
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
- }
return nil
}
@@ -399,16 +388,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.mu.Lock()
- *o = e.linger
- e.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
+ return tcpip.ErrUnknownProtocolOption
}
func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
@@ -524,7 +504,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: r.LocalAddress,
@@ -539,11 +518,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, err = e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
e.ID = id
- e.route = r.Clone()
+ e.route = r
e.RegisterNICID = nicID
e.state = stateConnected
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 3666bac0f..e5e247342 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -85,8 +85,6 @@ type endpoint struct {
stats tcpip.TransportEndpointStats `state:"nosave"`
bound bool
boundNIC tcpip.NICID
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
@@ -206,8 +204,8 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha
}
// Peek implements tcpip.Endpoint.Peek.
-func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
@@ -306,16 +304,10 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
+ switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- ep.mu.Lock()
- ep.linger = *v
- ep.mu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -374,18 +366,16 @@ func (ep *endpoint) LastError() *tcpip.Error {
return err
}
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (ep *endpoint) UpdateLastError(err *tcpip.Error) {
+ ep.lastErrorMu.Lock()
+ ep.lastError = err
+ ep.lastErrorMu.Unlock()
+}
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- ep.mu.Lock()
- *o = ep.linger
- ep.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrNotSupported
- }
+ return tcpip.ErrNotSupported
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 0840a4b3d..7befcfc9b 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -85,8 +85,6 @@ type endpoint struct {
// Connect(), and is valid only when conneted is true.
route *stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -227,6 +225,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrInvalidOptionValue
}
+ if opts.To != nil {
+ // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint.
+ if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+ }
+
n, ch, err := e.write(p, opts)
switch err {
case nil:
@@ -256,15 +261,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
e.mu.RLock()
+ defer e.mu.RUnlock()
if e.closed {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
payloadBytes, err := p.FullPayload()
if err != nil {
- e.mu.RUnlock()
return 0, nil, err
}
@@ -273,7 +277,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidOptionValue
}
dstAddr := ip.DestinationAddress()
@@ -295,39 +298,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If the user doesn't specify a destination, they should have
// connected to another address.
if !e.connected {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrDestinationRequired
}
- if e.route.IsResolutionRequired() {
- savedRoute := e.route
- // Promote lock to exclusive if using a shared route,
- // given that it may need to change in finishWrite.
- e.mu.RUnlock()
- e.mu.Lock()
-
- // Make sure that the route didn't change during the
- // time we didn't hold the lock.
- if !e.connected || savedRoute != e.route {
- e.mu.Unlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
- n, ch, err := e.finishWrite(payloadBytes, savedRoute)
- e.mu.Unlock()
- return n, ch, err
- }
-
- n, ch, err := e.finishWrite(payloadBytes, e.route)
- e.mu.RUnlock()
- return n, ch, err
+ return e.finishWrite(payloadBytes, e.route)
}
// The caller provided a destination. Reject destination address if it
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
if e.bound && nic != 0 && nic != e.BindNICID {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrNoRoute
}
@@ -335,13 +315,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
if err != nil {
- e.mu.RUnlock()
return 0, nil, err
}
n, ch, err := e.finishWrite(payloadBytes, route)
route.Release()
- e.mu.RUnlock()
return n, ch, err
}
@@ -386,8 +364,8 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
// Peek implements tcpip.Endpoint.Peek.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -397,6 +375,11 @@ func (*endpoint) Disconnect() *tcpip.Error {
// Connect implements tcpip.Endpoint.Connect.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint.
+ if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
e.mu.Lock()
defer e.mu.Unlock()
@@ -425,11 +408,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer route.Release()
if e.associated {
// Re-register the endpoint with the appropriate NIC.
if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
+ route.Release()
return err
}
e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
@@ -437,7 +420,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Save the route we've connected via.
- e.route = route.Clone()
+ e.route = route
e.connected = true
return nil
@@ -520,16 +503,10 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
+ switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -581,16 +558,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.mu.Lock()
- *o = e.linger
- e.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
+ return tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -625,6 +593,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ e.mu.RLock()
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full or if this is an unassociated
@@ -637,6 +606,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// sockets.
if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
return
@@ -644,6 +614,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if e.rcvBufSize >= e.rcvBufSizeMax {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
return
@@ -655,11 +626,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// If bound to a NIC, only accept data for that NIC.
if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
// If bound to an address, only accept data for that address.
if e.BindAddr != "" && e.BindAddr != remoteAddr {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
}
@@ -668,6 +641,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// connected to.
if e.connected && e.route.RemoteAddress != remoteAddr {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
@@ -702,6 +676,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
if wasEmpty {
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 3e1041cbe..2d96a65bd 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -778,7 +778,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
for {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index c944dccc0..0dc710276 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -462,7 +462,7 @@ func (h *handshake) processSegments() *tcpip.Error {
func (h *handshake) resolveRoute() *tcpip.Error {
// Set up the wakers.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
resolutionWaker := &sleep.Waker{}
s.AddWaker(resolutionWaker, wakerForResolution)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
@@ -470,24 +470,27 @@ func (h *handshake) resolveRoute() *tcpip.Error {
// Initial action is to resolve route.
index := wakerForResolution
+ attemptedResolution := false
for {
switch index {
case wakerForResolution:
- if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
- if err == tcpip.ErrNoLinkAddress {
- h.ep.stats.SendErrors.NoLinkAddr.Increment()
- } else if err != nil {
+ if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock {
+ if err != nil {
h.ep.stats.SendErrors.NoRoute.Increment()
}
// Either success (err == nil) or failure.
return err
}
+ if attemptedResolution {
+ h.ep.stats.SendErrors.NoLinkAddr.Increment()
+ return tcpip.ErrNoLinkAddress
+ }
+ attemptedResolution = true
// Resolution not completed. Keep trying...
case wakerForNotification:
n := h.ep.fetchNotifications()
if n&notifyClose != 0 {
- h.ep.route.RemoveWaker(resolutionWaker)
return tcpip.ErrAborted
}
if n&notifyDrain != 0 {
@@ -563,7 +566,7 @@ func (h *handshake) start() *tcpip.Error {
// complete completes the TCP 3-way handshake initiated by h.start().
func (h *handshake) complete() *tcpip.Error {
// Set up the wakers.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
resendWaker := sleep.Waker{}
s.AddWaker(&resendWaker, wakerForResend)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
@@ -1512,7 +1515,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
// Initialize the sleeper based on the wakers in funcs.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
for i := range funcs {
s.AddWaker(funcs[i].w, i)
}
@@ -1699,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
const notification = 2
const timeWaitDone = 3
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
defer s.Done()
s.AddWaker(&e.newSegmentWaker, newSegment)
s.AddWaker(&e.notificationWaker, notification)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 87eda2efb..6e3c8860e 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -502,9 +502,6 @@ type endpoint struct {
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
- // bindToDevice is set to the NIC on which to bind or disabled if 0.
- bindToDevice tcpip.NICID
-
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -674,9 +671,6 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -1040,7 +1034,8 @@ func (e *endpoint) Close() {
return
}
- if e.linger.Enabled && e.linger.Timeout == 0 {
+ linger := e.SocketOptions().GetLinger()
+ if linger.Enabled && linger.Timeout == 0 {
s := e.EndpointState()
isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv
if isResetState {
@@ -1305,6 +1300,15 @@ func (e *endpoint) LastError() *tcpip.Error {
return e.lastErrorLocked()
}
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+ e.LockUser()
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+ e.UnlockUser()
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -1498,7 +1502,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -1506,10 +1510,10 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// but has some pending unread data.
if s := e.EndpointState(); !s.connected() && s != StateClose {
if s == StateError {
- return 0, tcpip.ControlMessages{}, e.hardErrorLocked()
+ return 0, e.hardErrorLocked()
}
e.stats.ReadErrors.InvalidEndpointState.Increment()
- return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ return 0, tcpip.ErrInvalidEndpointState
}
e.rcvListMu.Lock()
@@ -1518,9 +1522,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.EndpointState().connected() {
e.stats.ReadErrors.ReadClosed.Increment()
- return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ return 0, tcpip.ErrClosedForReceive
}
- return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ return 0, tcpip.ErrWouldBlock
}
// Make a copy of vec so we can modify the slide headers.
@@ -1535,7 +1539,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
for len(v) > 0 {
if len(vec) == 0 {
- return num, tcpip.ControlMessages{}, nil
+ return num, nil
}
if len(vec[0]) == 0 {
vec = vec[1:]
@@ -1550,7 +1554,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
}
}
- return num, tcpip.ControlMessages{}, nil
+ return num, nil
}
// selectWindowLocked returns the new window without checking for shrinking or scaling
@@ -1814,18 +1818,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+func (e *endpoint) HasNIC(id int32) bool {
+ return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
- case *tcpip.BindToDeviceOption:
- id := tcpip.NICID(*v)
- if id != 0 && !e.stack.HasNIC(id) {
- return tcpip.ErrUnknownDevice
- }
- e.LockUser()
- e.bindToDevice = id
- e.UnlockUser()
-
case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
e.keepalive.idle = time.Duration(*v)
@@ -1838,9 +1837,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- case *tcpip.OutOfBandInlineOption:
- // We don't currently support disabling this option.
-
case *tcpip.TCPUserTimeoutOption:
e.LockUser()
e.userTimeout = time.Duration(*v)
@@ -1909,11 +1905,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- e.LockUser()
- e.linger = *v
- e.UnlockUser()
-
default:
return nil
}
@@ -2014,11 +2005,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
switch o := opt.(type) {
- case *tcpip.BindToDeviceOption:
- e.LockUser()
- *o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.UnlockUser()
-
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
e.LockUser()
@@ -2046,10 +2032,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
*o = tcpip.TCPUserTimeoutOption(e.userTimeout)
e.UnlockUser()
- case *tcpip.OutOfBandInlineOption:
- // We don't currently support disabling this option.
- *o = 1
-
case *tcpip.CongestionControlOption:
e.LockUser()
*o = e.cc
@@ -2078,11 +2060,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
Port: port,
}
- case *tcpip.LingerOption:
- e.LockUser()
- *o = e.linger
- e.UnlockUser()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -2230,11 +2207,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
}
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
if err != tcpip.ErrPortInUse || !reuse {
return false, nil
}
@@ -2272,15 +2250,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
tcpEP.notifyProtocolGoroutine(notifyAbort)
tcpEP.UnlockUser()
// Now try and Reserve again if it fails then we skip.
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
return false, nil
}
}
id := e.ID
id.LocalPort = p
- if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr)
+ if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr)
if err == tcpip.ErrPortInUse {
return false, nil
}
@@ -2291,7 +2269,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// the selected port.
e.ID = id
e.isPortReserved = true
- e.boundBindToDevice = e.bindToDevice
+ e.boundBindToDevice = bindToDevice
e.boundPortFlags = e.portFlags
e.boundDest = addr
return true, nil
@@ -2302,7 +2280,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
e.isRegistered = true
e.setEndpointState(StateConnecting)
- e.route = r.Clone()
+ r.Acquire()
+ e.route = r
e.boundNICID = nicID
e.effectiveNetProtos = netProtos
e.connectingAddress = connectingAddr
@@ -2643,7 +2622,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
e.ID.LocalAddress = addr.Addr
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
id := e.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
@@ -2654,7 +2634,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// demuxer. Further connected endpoints always have a remote
// address/port. Hence this will only return an error if there is a matching
// listening endpoint.
- if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil {
+ if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil {
return false
}
return true
@@ -2663,7 +2643,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
return err
}
- e.boundBindToDevice = e.bindToDevice
+ e.boundBindToDevice = bindToDevice
e.boundPortFlags = e.portFlags
// TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct.
e.boundNICID = nic
@@ -2727,6 +2707,41 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
return true
}
+func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+ // Update last error first.
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ // Update the error queue if IP_RECVERR is enabled.
+ if e.SocketOptions().GetRecvError() {
+ e.SocketOptions().QueueErr(&tcpip.SockError{
+ Err: err,
+ ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber),
+ ErrType: errType,
+ ErrCode: errCode,
+ ErrInfo: extra,
+ // Linux passes the payload with the TCP header. We don't know if the TCP
+ // header even exists, it may not for fragmented packets.
+ Payload: pkt.Data.ToView(),
+ Dst: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.RemoteAddress,
+ Port: id.RemotePort,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ },
+ NetProto: pkt.NetworkProtocolNumber,
+ })
+ }
+
+ // Notify of the error.
+ e.notifyProtocolGoroutine(notifyError)
+}
+
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
switch typ {
@@ -2741,16 +2756,10 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.notifyProtocolGoroutine(notifyMTUChanged)
case stack.ControlNoRoute:
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrNoRoute
- e.lastErrorMu.Unlock()
- e.notifyProtocolGoroutine(notifyError)
+ e.onICMPError(tcpip.ErrNoRoute, id, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt)
case stack.ControlNetworkUnreachable:
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrNetworkUnreachable
- e.lastErrorMu.Unlock()
- e.notifyProtocolGoroutine(notifyError)
+ e.onICMPError(tcpip.ErrNetworkUnreachable, id, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt)
}
}
@@ -3008,6 +3017,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
Ssthresh: e.snd.sndSsthresh,
SndCAAckCount: e.snd.sndCAAckCount,
Outstanding: e.snd.outstanding,
+ SackedOut: e.snd.sackedOut,
SndWnd: e.snd.sndWnd,
SndUna: e.snd.sndUna,
SndNxt: e.snd.sndNxt,
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index f2b1b68da..405a6dce7 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -172,14 +172,12 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// If we started off with a window larger than what can he held in
// the 16bit window field, we ceil the value to the max value.
- // While ceiling, we still do not want to grow the right edge when
- // not applicable.
if scaledWnd > math.MaxUint16 {
- if toGrow {
- scaledWnd = seqnum.Size(math.MaxUint16)
- } else {
- scaledWnd = seqnum.Size(uint16(scaledWnd))
- }
+ scaledWnd = seqnum.Size(math.MaxUint16)
+
+ // Ensure that the stashed receive window always reflects what
+ // is being advertised.
+ r.rcvWnd = scaledWnd << r.rcvWndScale
}
return r.rcvNxt, scaledWnd
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index baec762e1..cc991aba6 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -137,6 +137,9 @@ type sender struct {
// that have been sent but not yet acknowledged.
outstanding int
+ // sackedOut is the number of packets which are selectively acked.
+ sackedOut int
+
// sndWnd is the send window size.
sndWnd seqnum.Size
@@ -372,6 +375,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
m = 1
}
+ oldMSS := s.maxPayloadSize
s.maxPayloadSize = m
if s.gso {
s.ep.gso.MSS = uint16(m)
@@ -394,6 +398,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
// Rewind writeNext to the first segment exceeding the MTU. Do nothing
// if it is already before such a packet.
+ nextSeg := s.writeNext
for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
if seg == s.writeNext {
// We got to writeNext before we could find a segment
@@ -401,16 +406,22 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
break
}
- if seg.data.Size() > m {
+ if nextSeg == s.writeNext && seg.data.Size() > m {
// We found a segment exceeding the MTU. Rewind
// writeNext and try to retransmit it.
- s.writeNext = seg
- break
+ nextSeg = seg
+ }
+
+ if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ // Update sackedOut for new maximum payload size.
+ s.sackedOut -= s.pCount(seg, oldMSS)
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
}
// Since we likely reduced the number of outstanding packets, we may be
// ready to send some more.
+ s.writeNext = nextSeg
s.sendData()
}
@@ -629,13 +640,13 @@ func (s *sender) retransmitTimerExpired() bool {
// pCount returns the number of packets in the segment. Due to GSO, a segment
// can be composed of multiple packets.
-func (s *sender) pCount(seg *segment) int {
+func (s *sender) pCount(seg *segment, maxPayloadSize int) int {
size := seg.data.Size()
if size == 0 {
return 1
}
- return (size-1)/s.maxPayloadSize + 1
+ return (size-1)/maxPayloadSize + 1
}
// splitSeg splits a given segment at the size specified and inserts the
@@ -1023,7 +1034,7 @@ func (s *sender) sendData() {
break
}
dataSent = true
- s.outstanding += s.pCount(seg)
+ s.outstanding += s.pCount(seg, s.maxPayloadSize)
s.writeNext = seg.Next()
}
@@ -1038,6 +1049,7 @@ func (s *sender) enterRecovery() {
// We inflate the cwnd by 3 to account for the 3 packets which triggered
// the 3 duplicate ACKs and are now not in flight.
s.sndCwnd = s.sndSsthresh + 3
+ s.sackedOut = 0
s.fr.first = s.sndUna
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
@@ -1207,6 +1219,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
s.rc.detectReorder(seg)
seg.acked = true
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
seg = seg.Next()
}
@@ -1380,10 +1393,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
datalen := seg.logicalLen()
if datalen > ackLeft {
- prevCount := s.pCount(seg)
+ prevCount := s.pCount(seg, s.maxPayloadSize)
seg.data.TrimFront(int(ackLeft))
seg.sequenceNumber.UpdateForward(ackLeft)
- s.outstanding -= prevCount - s.pCount(seg)
+ s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize)
break
}
@@ -1399,11 +1412,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.writeList.Remove(seg)
- // If SACK is enabled then Only reduce outstanding if
+ // If SACK is enabled then only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
- s.outstanding -= s.pCount(seg)
+ s.outstanding -= s.pCount(seg, s.maxPayloadSize)
+ } else {
+ s.sackedOut -= s.pCount(seg, s.maxPayloadSize)
}
seg.decRef()
ackLeft -= datalen
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index ef7f5719f..faf0c0ad7 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -590,3 +590,45 @@ func TestSACKRecovery(t *testing.T) {
expected++
}
}
+
+// TestSACKUpdateSackedOut tests the sacked out field is updated when a SACK
+// is received.
+func TestSACKUpdateSackedOut(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan struct{})
+ ackNum := 0
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint Sender.SackedOut is what we expect.
+ if state.Sender.SackedOut != 2 && ackNum == 0 {
+ t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut)
+ }
+
+ if state.Sender.SackedOut != 0 && ackNum == 1 {
+ t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut)
+ }
+ if ackNum > 0 {
+ close(probeDone)
+ }
+ ackNum++
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ sendAndReceive(t, c, 8)
+
+ // ACK for [3-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload))
+ bytesRead := 2 * maxPayload
+ end := start.Add(seqnum.Size(bytesRead))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ bytesRead += 3 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ <-probeDone
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 1759ebea9..cf60d5b53 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1380,9 +1380,8 @@ func TestConnectBindToDevice(t *testing.T) {
defer c.Cleanup()
c.Create(-1)
- bindToDevice := tcpip.BindToDeviceOption(test.device)
- if err := c.EP.SetSockOpt(&bindToDevice); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", bindToDevice, bindToDevice, err)
+ if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err)
}
// Start connection attempt.
waitEntry, _ := waiter.NewChannelEntry(nil)
@@ -1932,6 +1931,84 @@ func TestFullWindowReceive(t *testing.T) {
)
}
+// Test the stack receive window advertisement on receiving segments smaller than
+// segment overhead. It tests for the right edge of the window to not grow when
+// the endpoint is not being read from.
+func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize,
+ Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)),
+ }
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
+
+ c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Bump up the receive buffer size such that, when the receive window grows,
+ // the scaled window exceeds maxUint16.
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err)
+ }
+
+ // Keep the payload size < segment overhead and such that it is a multiple
+ // of the window scaled value. This enables the test to perform equality
+ // checks on the incoming receive window.
+ payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale))
+ payloadLen := seqnum.Size(len(payload))
+ iss := seqnum.Value(789)
+ seqNum := iss.Add(1)
+
+ // Send payload to the endpoint and return the advertised receive window
+ // from the endpoint.
+ getIncomingRcvWnd := func() uint32 {
+ c.SendPacket(payload, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: seqNum,
+ AckNum: c.IRS.Add(1),
+ Flags: header.TCPFlagAck,
+ RcvWnd: 30000,
+ })
+ seqNum = seqNum.Add(payloadLen)
+
+ pkt := c.GetPacket()
+ return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale
+ }
+
+ // Read the advertised receive window with the ACK for payload.
+ rcvWnd := getIncomingRcvWnd()
+
+ // Check if the subsequent ACK to our send has not grown the right edge of
+ // the window.
+ if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+
+ // Read the data so that the subsequent ACK from the endpoint
+ // grows the right edge of the window.
+ if _, _, err := c.EP.Read(nil); err != nil {
+ t.Fatalf("got Read(nil) = %s", err)
+ }
+
+ // Check if we have received max uint16 as our advertised
+ // scaled window now after a read above.
+ maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale)
+ if got, want := getIncomingRcvWnd(), maxRcv; got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+
+ // Check if the subsequent ACK to our send has not grown the right edge of
+ // the window.
+ if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+}
+
func TestNoWindowShrinking(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -4148,7 +4225,7 @@ func TestReadAfterClosedState(t *testing.T) {
// Check that peek works.
peekBuf := make([]byte, 10)
- n, _, err := c.EP.Peek([][]byte{peekBuf})
+ n, err := c.EP.Peek([][]byte{peekBuf})
if err != nil {
t.Fatalf("Peek failed: %s", err)
}
@@ -4174,7 +4251,7 @@ func TestReadAfterClosedState(t *testing.T) {
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
- if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
+ if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
}
@@ -4429,7 +4506,7 @@ func TestBindToDeviceOption(t *testing.T) {
name string
setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
- getBindToDevice tcpip.BindToDeviceOption
+ getBindToDevice int32
}{
{"GetDefaultValue", nil, nil, 0},
{"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
@@ -4439,15 +4516,13 @@ func TestBindToDeviceOption(t *testing.T) {
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
- bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ bindToDevice := int32(*testAction.setBindToDevice)
+ if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption(88888)
- if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err)
- } else if bindToDevice != testAction.getBindToDevice {
+ bindToDevice := ep.SocketOptions().GetBindToDevice()
+ if bindToDevice != testAction.getBindToDevice {
t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice)
}
})
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 010a23e45..ee55f030c 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -635,11 +635,11 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
- NextHeader: uint8(tcp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
+ TransportProtocol: tcp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: src,
+ DstAddr: dst,
})
// Initialize the TCP header.
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 5043e7aa5..9b9e4deb0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -30,10 +30,11 @@ import (
// +stateify savable
type udpPacket struct {
udpPacketEntry
- senderAddress tcpip.FullAddress
- packetInfo tcpip.IPPacketInfo
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- timestamp int64
+ senderAddress tcpip.FullAddress
+ destinationAddress tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
// tos stores either the receiveTOS or receiveTClass value.
tos uint8
}
@@ -108,7 +109,6 @@ type endpoint struct {
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
portFlags ports.Flags
- bindToDevice tcpip.NICID
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -143,9 +143,6 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -228,6 +225,13 @@ func (e *endpoint) LastError() *tcpip.Error {
return err
}
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+}
+
// Abort implements stack.TransportEndpoint.Abort.
func (e *endpoint) Abort() {
e.Close()
@@ -323,6 +327,10 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
cm.HasIPPacketInfo = true
cm.PacketInfo = p.packetInfo
}
+ if e.ops.GetReceiveOriginalDstAddress() {
+ cm.HasOriginalDstAddress = true
+ cm.OriginalDstAddress = p.destinationAddress
+ }
return p.data.ToView(), cm, nil
}
@@ -509,6 +517,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
+ so := e.SocketOptions()
+ if so.GetRecvError() {
+ so.QueueLocalErr(
+ tcpip.ErrMessageTooLong,
+ route.NetProto,
+ header.UDPMaximumPacketSize,
+ tcpip.FullAddress{
+ NIC: route.NICID(),
+ Addr: route.RemoteAddress,
+ Port: dstPort,
+ },
+ v,
+ )
+ }
return 0, nil, tcpip.ErrMessageTooLong
}
@@ -545,8 +567,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
@@ -636,6 +658,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+func (e *endpoint) HasNIC(id int32) bool {
+ return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
@@ -752,22 +778,8 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
delete(e.multicastMemberships, memToRemove)
- case *tcpip.BindToDeviceOption:
- id := tcpip.NICID(*v)
- if id != 0 && !e.stack.HasNIC(id) {
- return tcpip.ErrUnknownDevice
- }
- e.mu.Lock()
- e.bindToDevice = id
- e.mu.Unlock()
-
case *tcpip.SocketDetachFilterOption:
return nil
-
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
}
return nil
}
@@ -841,16 +853,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
}
e.mu.Unlock()
- case *tcpip.BindToDeviceOption:
- e.mu.RLock()
- *o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.mu.RUnlock()
-
- case *tcpip.LingerOption:
- e.mu.RLock()
- *o = e.linger
- e.mu.RUnlock()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -1004,7 +1006,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: e.ID.LocalAddress,
@@ -1032,6 +1033,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
@@ -1042,7 +1044,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.ID = id
e.boundBindToDevice = btd
- e.route = r.Clone()
+ e.route = r
e.dstPort = addr.Port
e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
@@ -1100,21 +1102,22 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp
}
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
if err != nil {
- return id, e.bindToDevice, err
+ return id, bindToDevice, err
}
id.LocalPort = port
}
e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{})
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
- return id, e.bindToDevice, err
+ return id, bindToDevice, err
}
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
@@ -1311,6 +1314,11 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
Addr: id.RemoteAddress,
Port: header.UDP(hdr).SourcePort(),
},
+ destinationAddress: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: header.UDP(hdr).DestinationPort(),
+ },
}
packet.data = pkt.Data
e.rcvList.PushBack(packet)
@@ -1341,15 +1349,63 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
}
+func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+ // Update last error first.
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ // Update the error queue if IP_RECVERR is enabled.
+ if e.SocketOptions().GetRecvError() {
+ // Linux passes the payload without the UDP header.
+ var payload []byte
+ udp := header.UDP(pkt.Data.ToView())
+ if len(udp) >= header.UDPMinimumSize {
+ payload = udp.Payload()
+ }
+
+ e.SocketOptions().QueueErr(&tcpip.SockError{
+ Err: err,
+ ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber),
+ ErrType: errType,
+ ErrCode: errCode,
+ ErrInfo: extra,
+ Payload: payload,
+ Dst: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.RemoteAddress,
+ Port: id.RemotePort,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ },
+ NetProto: pkt.NetworkProtocolNumber,
+ })
+ }
+
+ // Notify of the error.
+ e.waiterQueue.Notify(waiter.EventErr)
+}
+
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
if e.EndpointState() == StateConnected {
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrConnectionRefused
- e.lastErrorMu.Unlock()
-
- e.waiterQueue.Notify(waiter.EventErr)
+ var errType byte
+ var errCode byte
+ switch pkt.NetworkProtocolNumber {
+ case header.IPv4ProtocolNumber:
+ errType = byte(header.ICMPv4DstUnreachable)
+ errCode = byte(header.ICMPv4PortUnreachable)
+ case header.IPv6ProtocolNumber:
+ errType = byte(header.ICMPv6DstUnreachable)
+ errCode = byte(header.ICMPv6PortUnreachable)
+ default:
+ panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber))
+ }
+ e.onICMPError(tcpip.ErrConnectionRefused, id, errType, errCode, extra, pkt)
return
}
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 14e4648cd..d7fc21f11 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -78,7 +78,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
route.ResolveWith(r.pkt.SourceLinkAddress())
ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
- if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
route.Release()
return nil, err
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index e384f52dd..8429f34b4 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -452,12 +452,12 @@ func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
+ TrafficClass: testTOS,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
@@ -554,7 +554,7 @@ func TestBindToDeviceOption(t *testing.T) {
name string
setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
- getBindToDevice tcpip.BindToDeviceOption
+ getBindToDevice int32
}{
{"GetDefaultValue", nil, nil, 0},
{"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
@@ -564,15 +564,13 @@ func TestBindToDeviceOption(t *testing.T) {
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
- bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ bindToDevice := int32(*testAction.setBindToDevice)
+ if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption(88888)
- if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err)
- } else if bindToDevice != testAction.getBindToDevice {
+ bindToDevice := ep.SocketOptions().GetBindToDevice()
+ if bindToDevice != testAction.getBindToDevice {
t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice)
}
})
@@ -1427,6 +1425,93 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
+func TestReadRecvOriginalDstAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ expectedOriginalDstAddr tcpip.FullAddress
+ }{
+ {
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort},
+ },
+ {
+ name: "IPv4 multicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: multicastV4,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort},
+ },
+ {
+ name: "IPv4 broadcast",
+ proto: header.IPv4ProtocolNumber,
+ flow: broadcast,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort},
+ },
+ {
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort},
+ },
+ {
+ name: "IPv6 multicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: multicastV6,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(test.proto)
+
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%#v): %s", bindAddr, err)
+ }
+
+ if test.flow.isMulticast() {
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
+ }
+ }
+
+ c.ep.SocketOptions().SetReceiveOriginalDstAddress(true)
+
+ testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr))
+
+ if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
+ t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
+ }
+ })
+ }
+}
+
func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1994,12 +2079,12 @@ func TestShortHeader(t *testing.T) {
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(udpSize),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
+ TrafficClass: testTOS,
+ PayloadLength: uint16(udpSize),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.