summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD15
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go57
-rw-r--r--pkg/tcpip/buffer/view.go37
-rw-r--r--pkg/tcpip/buffer/view_test.go68
-rw-r--r--pkg/tcpip/checker/checker.go414
-rw-r--r--pkg/tcpip/header/BUILD5
-rw-r--r--pkg/tcpip/header/checksum_test.go94
-rw-r--r--pkg/tcpip/header/icmpv4.go30
-rw-r--r--pkg/tcpip/header/icmpv6.go40
-rw-r--r--pkg/tcpip/header/igmp.go181
-rw-r--r--pkg/tcpip/header/igmp_test.go110
-rw-r--r--pkg/tcpip/header/ipv4.go222
-rw-r--r--pkg/tcpip/header/ipv4_test.go179
-rw-r--r--pkg/tcpip/header/ipv6.go62
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go345
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers_test.go356
-rw-r--r--pkg/tcpip/header/ipv6_fragment.go42
-rw-r--r--pkg/tcpip/header/ipv6_test.go44
-rw-r--r--pkg/tcpip/header/mld.go103
-rw-r--r--pkg/tcpip/header/mld_test.go61
-rw-r--r--pkg/tcpip/header/ndp_options.go2
-rw-r--r--pkg/tcpip/header/parse/parse.go3
-rw-r--r--pkg/tcpip/link/channel/BUILD1
-rw-r--r--pkg/tcpip/link/channel/channel.go30
-rw-r--r--pkg/tcpip/link/ethernet/BUILD16
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go10
-rw-r--r--pkg/tcpip/link/ethernet/ethernet_test.go71
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go18
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go18
-rw-r--r--pkg/tcpip/link/loopback/loopback.go17
-rw-r--r--pkg/tcpip/link/muxed/BUILD1
-rw-r--r--pkg/tcpip/link/muxed/injectable.go8
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go6
-rw-r--r--pkg/tcpip/link/nested/BUILD1
-rw-r--r--pkg/tcpip/link/nested/nested.go6
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go4
-rw-r--r--pkg/tcpip/link/pipe/pipe.go7
-rw-r--r--pkg/tcpip/link/qdisc/fifo/BUILD1
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go16
-rw-r--r--pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go1
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go17
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go32
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go66
-rw-r--r--pkg/tcpip/link/tun/device.go40
-rw-r--r--pkg/tcpip/link/waitable/BUILD2
-rw-r--r--pkg/tcpip/link/waitable/waitable.go12
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go6
-rw-r--r--pkg/tcpip/network/BUILD3
-rw-r--r--pkg/tcpip/network/arp/arp.go34
-rw-r--r--pkg/tcpip/network/arp/arp_test.go18
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD3
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap.go77
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap_test.go126
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go81
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go207
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go194
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go246
-rw-r--r--pkg/tcpip/network/ip/BUILD26
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go676
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go806
-rw-r--r--pkg/tcpip/network/ip_test.go424
-rw-r--r--pkg/tcpip/network/ipv4/BUILD7
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go58
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go345
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go215
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go377
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go333
-rw-r--r--pkg/tcpip/network/ipv6/BUILD18
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go100
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go254
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go560
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go492
-rw-r--r--pkg/tcpip/network/ipv6/mld.go262
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go297
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go351
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go168
-rw-r--r--pkg/tcpip/network/multicast_group_test.go1261
-rw-r--r--pkg/tcpip/network/testutil/testutil.go15
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go5
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go11
-rw-r--r--pkg/tcpip/socketops.go520
-rw-r--r--pkg/tcpip/stack/BUILD7
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go179
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go22
-rw-r--r--pkg/tcpip/stack/conntrack.go14
-rw-r--r--pkg/tcpip/stack/forwarding_test.go90
-rw-r--r--pkg/tcpip/stack/iptables.go66
-rw-r--r--pkg/tcpip/stack/iptables_types.go10
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go135
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go110
-rw-r--r--pkg/tcpip/stack/ndp_test.go255
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go107
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go498
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go118
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go347
-rw-r--r--pkg/tcpip/stack/nic.go283
-rw-r--r--pkg/tcpip/stack/nud.go21
-rw-r--r--pkg/tcpip/stack/nud_test.go53
-rw-r--r--pkg/tcpip/stack/pending_packets.go2
-rw-r--r--pkg/tcpip/stack/registration.go68
-rw-r--r--pkg/tcpip/stack/route.go310
-rw-r--r--pkg/tcpip/stack/stack.go178
-rw-r--r--pkg/tcpip/stack/stack_test.go489
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go5
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go22
-rw-r--r--pkg/tcpip/stack/transport_test.go137
-rw-r--r--pkg/tcpip/tcpip.go357
-rw-r--r--pkg/tcpip/tcpip_test.go84
-rw-r--r--pkg/tcpip/tests/integration/BUILD3
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go102
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go56
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go218
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go547
-rw-r--r--pkg/tcpip/tests/integration/route_test.go69
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go129
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go101
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go167
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go8
-rw-r--r--pkg/tcpip/transport/tcp/BUILD7
-rw-r--r--pkg/tcpip/transport/tcp/accept.go151
-rw-r--r--pkg/tcpip/transport/tcp/connect.go179
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go18
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go751
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go11
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go7
-rw-r--r--pkg/tcpip/transport/tcp/rack.go132
-rw-r--r--pkg/tcpip/transport/tcp/rack_state.go5
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go50
-rw-r--r--pkg/tcpip/transport/tcp/reno_recovery.go67
-rw-r--r--pkg/tcpip/transport/tcp/sack_recovery.go120
-rw-r--r--pkg/tcpip/transport/tcp/segment.go16
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go13
-rw-r--r--pkg/tcpip/transport/tcp/segment_unsafe.go3
-rw-r--r--pkg/tcpip/transport/tcp/snd.go441
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go415
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go42
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go491
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go23
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go14
-rw-r--r--pkg/tcpip/transport/tcp/timer.go4
-rw-r--r--pkg/tcpip/transport/udp/BUILD2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go554
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go9
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go293
146 files changed, 14700 insertions, 5936 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 454e07662..89b765f1b 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,10 +1,25 @@
load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
+go_template_instance(
+ name = "sock_err_list",
+ out = "sock_err_list.go",
+ package = "tcpip",
+ prefix = "sockError",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*SockError",
+ "Linker": "*SockError",
+ },
+)
+
go_library(
name = "tcpip",
srcs = [
+ "sock_err_list.go",
+ "socketops.go",
"tcpip.go",
"time_unsafe.go",
"timer.go",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 4f551cd92..7193f56ad 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -286,45 +286,47 @@ type opErrorer interface {
// commonRead implements the common logic between net.Conn.Read and
// net.PacketConn.ReadFrom.
-func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) {
+func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) {
select {
case <-deadline:
- return nil, errorer.newOpError("read", &timeoutError{})
+ return 0, errorer.newOpError("read", &timeoutError{})
default:
}
- read, _, err := ep.Read(addr)
+ w := tcpip.SliceWriter(b)
+ opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil}
+ res, err := ep.Read(&w, len(b), opts)
if err == tcpip.ErrWouldBlock {
- if dontWait {
- return nil, errWouldBlock
- }
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
for {
- read, _, err = ep.Read(addr)
+ res, err = ep.Read(&w, len(b), opts)
if err != tcpip.ErrWouldBlock {
break
}
select {
case <-deadline:
- return nil, errorer.newOpError("read", &timeoutError{})
+ return 0, errorer.newOpError("read", &timeoutError{})
case <-notifyCh:
}
}
}
if err == tcpip.ErrClosedForReceive {
- return nil, io.EOF
+ return 0, io.EOF
}
if err != nil {
- return nil, errorer.newOpError("read", errors.New(err.String()))
+ return 0, errorer.newOpError("read", errors.New(err.String()))
}
- return read, nil
+ if addr != nil {
+ *addr = res.RemoteAddr
+ }
+ return res.Count, nil
}
// Read implements net.Conn.Read.
@@ -334,31 +336,11 @@ func (c *TCPConn) Read(b []byte) (int, error) {
deadline := c.readCancel()
- numRead := 0
- defer func() {
- if numRead != 0 {
- c.ep.ModerateRecvBuf(numRead)
- }
- }()
- for numRead != len(b) {
- if len(c.read) == 0 {
- var err error
- c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0)
- if err != nil {
- if numRead != 0 {
- return numRead, nil
- }
- return numRead, err
- }
- }
- n := copy(b[numRead:], c.read)
- c.read.TrimFront(n)
- numRead += n
- if len(c.read) == 0 {
- c.read = nil
- }
+ n, err := commonRead(b, c.ep, c.wq, deadline, nil, c)
+ if n != 0 {
+ c.ep.ModerateRecvBuf(n)
}
- return numRead, nil
+ return n, err
}
// Write implements net.Conn.Write.
@@ -652,12 +634,11 @@ func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
deadline := c.readCancel()
var addr tcpip.FullAddress
- read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false)
+ n, err := commonRead(b, c.ep, c.wq, deadline, &addr, c)
if err != nil {
return 0, nil, err
}
-
- return copy(b, read), fullToUDPAddr(addr), nil
+ return n, fullToUDPAddr(addr), nil
}
func (c *UDPConn) Write(b []byte) (int, error) {
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 8db70a700..5dd1b1b6b 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -105,18 +105,18 @@ func (vv *VectorisedView) TrimFront(count int) {
}
// Read implements io.Reader.
-func (vv *VectorisedView) Read(v View) (copied int, err error) {
- count := len(v)
+func (vv *VectorisedView) Read(b []byte) (copied int, err error) {
+ count := len(b)
for count > 0 && len(vv.views) > 0 {
if count < len(vv.views[0]) {
vv.size -= count
- copy(v[copied:], vv.views[0][:count])
+ copy(b[copied:], vv.views[0][:count])
vv.views[0].TrimFront(count)
copied += count
return copied, nil
}
count -= len(vv.views[0])
- copy(v[copied:], vv.views[0])
+ copy(b[copied:], vv.views[0])
copied += len(vv.views[0])
vv.removeFirst()
}
@@ -145,6 +145,35 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int
return copied
}
+// ReadTo reads up to count bytes from vv to dst. It also removes them from vv
+// unless peek is true.
+func (vv *VectorisedView) ReadTo(dst io.Writer, count int, peek bool) (int, error) {
+ var err error
+ done := 0
+ for _, v := range vv.Views() {
+ remaining := count - done
+ if remaining <= 0 {
+ break
+ }
+ if len(v) > remaining {
+ v = v[:remaining]
+ }
+
+ var n int
+ n, err = dst.Write(v)
+ if n > 0 {
+ done += n
+ }
+ if err != nil {
+ break
+ }
+ }
+ if !peek {
+ vv.TrimFront(done)
+ }
+ return done, err
+}
+
// CapLength irreversibly reduces the length of the vectorised view.
func (vv *VectorisedView) CapLength(length int) {
if length < 0 {
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
index 726e54de9..e0ef8a94d 100644
--- a/pkg/tcpip/buffer/view_test.go
+++ b/pkg/tcpip/buffer/view_test.go
@@ -235,14 +235,16 @@ func TestToClone(t *testing.T) {
}
}
-func TestVVReadToVV(t *testing.T) {
- testCases := []struct {
- comment string
- vv VectorisedView
- bytesToRead int
- wantBytes string
- leftVV VectorisedView
- }{
+type readToTestCases struct {
+ comment string
+ vv VectorisedView
+ bytesToRead int
+ wantBytes string
+ leftVV VectorisedView
+}
+
+func createReadToTestCases() []readToTestCases {
+ return []readToTestCases{
{
comment: "large VV, short read",
vv: vv(30, "012345678901234567890123456789"),
@@ -279,8 +281,10 @@ func TestVVReadToVV(t *testing.T) {
leftVV: vv(0, ""),
},
}
+}
- for _, tc := range testCases {
+func TestVVReadToVV(t *testing.T) {
+ for _, tc := range createReadToTestCases() {
t.Run(tc.comment, func(t *testing.T) {
var readTo VectorisedView
inSize := tc.vv.Size()
@@ -301,6 +305,52 @@ func TestVVReadToVV(t *testing.T) {
}
}
+func TestVVReadTo(t *testing.T) {
+ for _, tc := range createReadToTestCases() {
+ t.Run(tc.comment, func(t *testing.T) {
+ var dst bytes.Buffer
+ origSize := tc.vv.Size()
+ copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, false /* peek */)
+ if got, want := copied, len(tc.wantBytes); err != nil || got != want {
+ t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want)
+ }
+ if got, want := string(dst.Bytes()), tc.wantBytes; got != want {
+ t.Errorf("got dst = %q, want %q", got, want)
+ }
+ if got, want := tc.vv.Size(), origSize-copied; got != want {
+ t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want)
+ }
+ if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want {
+ t.Errorf("got after-read data in tc.vv = %q, want %q", got, want)
+ }
+ })
+ }
+}
+
+func TestVVReadToPeek(t *testing.T) {
+ for _, tc := range createReadToTestCases() {
+ t.Run(tc.comment, func(t *testing.T) {
+ var dst bytes.Buffer
+ origSize := tc.vv.Size()
+ origData := string(tc.vv.ToView())
+ copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, true /* peek */)
+ if got, want := copied, len(tc.wantBytes); err != nil || got != want {
+ t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want)
+ }
+ if got, want := string(dst.Bytes()), tc.wantBytes; got != want {
+ t.Errorf("got dst = %q, want %q", got, want)
+ }
+ // Expect tc.vv is unchanged.
+ if got, want := tc.vv.Size(), origSize; got != want {
+ t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want)
+ }
+ if got, want := string(tc.vv.ToView()), origData; got != want {
+ t.Errorf("got after-read data in tc.vv = %q, want %q", got, want)
+ }
+ })
+ }
+}
+
func TestVVRead(t *testing.T) {
testCases := []struct {
comment string
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 8868cf4e3..0ac2000ca 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -20,6 +20,7 @@ import (
"encoding/binary"
"reflect"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -116,6 +117,10 @@ func TTL(ttl uint8) NetworkChecker {
v = ip.TTL()
case header.IPv6:
v = ip.HopLimit()
+ case *ipv6HeaderWithExtHdr:
+ v = ip.HopLimit()
+ default:
+ t.Fatalf("unrecognized header type %T for TTL evaluation", ip)
}
if v != ttl {
t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
@@ -216,6 +221,42 @@ func IPv4Options(want header.IPv4Options) NetworkChecker {
}
}
+// IPv4RouterAlert returns a checker that checks that the RouterAlert option is
+// set in an IPv4 packet.
+func IPv4RouterAlert() NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+ ip, ok := h[0].(header.IPv4)
+ if !ok {
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
+ }
+ iterator := ip.Options().MakeIterator()
+ for {
+ opt, done, err := iterator.Next()
+ if err != nil {
+ t.Fatalf("error acquiring next IPv4 option %s", err)
+ }
+ if done {
+ break
+ }
+ if opt.Type() != header.IPv4OptionRouterAlertType {
+ continue
+ }
+ want := [header.IPv4OptionRouterAlertLength]byte{
+ byte(header.IPv4OptionRouterAlertType),
+ header.IPv4OptionRouterAlertLength,
+ header.IPv4OptionRouterAlertValue,
+ header.IPv4OptionRouterAlertValue,
+ }
+ if diff := cmp.Diff(want[:], opt.Contents()); diff != "" {
+ t.Errorf("router alert option mismatch (-want +got):\n%s", diff)
+ }
+ return
+ }
+ t.Errorf("failed to find router alert option in %v", ip.Options())
+ }
+}
+
// FragmentOffset creates a checker that checks the FragmentOffset field.
func FragmentOffset(offset uint16) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -284,6 +325,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
}
}
+// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
+// field in ControlMessages.
+func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasOriginalDstAddress {
+ t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress)
+ } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" {
+ t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -904,6 +958,12 @@ func ICMPv4Payload(want []byte) TransportChecker {
t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
payload := icmpv4.Payload()
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(want) == 0 && len(payload) == 0 {
+ return
+ }
+
if diff := cmp.Diff(want, payload); diff != "" {
t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
}
@@ -994,12 +1054,86 @@ func ICMPv6Payload(want []byte) TransportChecker {
t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
payload := icmpv6.Payload()
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(want) == 0 && len(payload) == 0 {
+ return
+ }
+
if diff := cmp.Diff(want, payload); diff != "" {
t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
}
}
}
+// MLD creates a checker that checks that the packet contains a valid MLD
+// message for type of mldType, with potentially additional checks specified by
+// checkers.
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// MLD message as far as the size of the message (minSize) is concerned. The
+// values within the message are up to checkers to validate.
+func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ // Check normal ICMPv6 first.
+ ICMPv6(
+ ICMPv6Type(msgType),
+ ICMPv6Code(0))(t, h)
+
+ last := h[len(h)-1]
+
+ icmp := header.ICMPv6(last.Payload())
+ if got := len(icmp.MessageBody()); got < minSize {
+ t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
+ }
+
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay
+// field of a MLD message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid MLD message as far as the size is concerned.
+func MLDMaxRespDelay(want time.Duration) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.MLD(icmp.MessageBody())
+
+ if got := ns.MaximumResponseDelay(); got != want {
+ t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want)
+ }
+ }
+}
+
+// MLDMulticastAddress creates a checker that checks the Multicast Address
+// field of a MLD message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid MLD message as far as the size is concerned.
+func MLDMulticastAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.MLD(icmp.MessageBody())
+
+ if got := ns.MulticastAddress(); got != want {
+ t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want)
+ }
+ }
+}
+
// NDP creates a checker that checks that the packet contains a valid NDP
// message for type of ty, with potentially additional checks specified by
// checkers.
@@ -1019,7 +1153,7 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N
last := h[len(h)-1]
icmp := header.ICMPv6(last.Payload())
- if got := len(icmp.NDPPayload()); got < minSize {
+ if got := len(icmp.MessageBody()); got < minSize {
t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
}
@@ -1053,7 +1187,7 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
if got := ns.TargetAddress(); got != want {
t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
@@ -1082,7 +1216,7 @@ func NDPNATargetAddress(want tcpip.Address) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
if got := na.TargetAddress(); got != want {
t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
@@ -1100,7 +1234,7 @@ func NDPNASolicitedFlag(want bool) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
if got := na.SolicitedFlag(); got != want {
t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
@@ -1171,7 +1305,7 @@ func NDPNAOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
ndpOptions(t, na.Options(), opts)
}
}
@@ -1186,7 +1320,7 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ndpOptions(t, ns.Options(), opts)
}
}
@@ -1211,7 +1345,273 @@ func NDPRSOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- rs := header.NDPRouterSolicit(icmp.NDPPayload())
+ rs := header.NDPRouterSolicit(icmp.MessageBody())
ndpOptions(t, rs.Options(), opts)
}
}
+
+// IGMP checks the validity and properties of the given IGMP packet. It is
+// expected to be used in conjunction with other IGMP transport checkers for
+// specific properties.
+func IGMP(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.IGMPProtocolNumber {
+ t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber)
+ }
+
+ igmp := header.IGMP(last.Payload())
+ for _, f := range checkers {
+ f(t, igmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// IGMPType creates a checker that checks the IGMP Type field.
+func IGMPType(want header.IGMPType) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ igmp, ok := h.(header.IGMP)
+ if !ok {
+ t.Fatalf("got transport header = %T, want = header.IGMP", h)
+ }
+ if got := igmp.Type(); got != want {
+ t.Errorf("got igmp.Type() = %d, want = %d", got, want)
+ }
+ }
+}
+
+// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field.
+func IGMPMaxRespTime(want time.Duration) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ igmp, ok := h.(header.IGMP)
+ if !ok {
+ t.Fatalf("got transport header = %T, want = header.IGMP", h)
+ }
+ if got := igmp.MaxRespTime(); got != want {
+ t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want)
+ }
+ }
+}
+
+// IGMPGroupAddress creates a checker that checks the IGMP Group Address field.
+func IGMPGroupAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ igmp, ok := h.(header.IGMP)
+ if !ok {
+ t.Fatalf("got transport header = %T, want = header.IGMP", h)
+ }
+ if got := igmp.GroupAddress(); got != want {
+ t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want)
+ }
+ }
+}
+
+// IPv6ExtHdrChecker is a function to check an extension header.
+type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader)
+
+// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers.
+func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) {
+ t.Helper()
+
+ ipv6 := header.IPv6(b)
+ if !ipv6.IsValid(len(b)) {
+ t.Error("not a valid IPv6 packet")
+ return
+ }
+
+ payloadIterator := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
+ buffer.View(ipv6.Payload()).ToVectorisedView(),
+ )
+
+ var rawPayloadHeader header.IPv6RawPayloadHeader
+ for {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done)
+ return
+ }
+ r, ok := h.(header.IPv6RawPayloadHeader)
+ if ok {
+ rawPayloadHeader = r
+ break
+ }
+ }
+
+ networkHeader := ipv6HeaderWithExtHdr{
+ IPv6: ipv6,
+ transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier),
+ payload: rawPayloadHeader.Buf.ToView(),
+ }
+
+ for _, checker := range checkers {
+ checker(t, []header.Network{&networkHeader})
+ }
+}
+
+// IPv6ExtHdr checks for the presence of extension headers.
+//
+// All the extension headers in headers will be checked exhaustively in the
+// order provided.
+func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr)
+ if !ok {
+ t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0])
+ return
+ }
+
+ payloadIterator := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()),
+ buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(),
+ )
+
+ for _, check := range headers {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done)
+ return
+ }
+ check(t, h)
+ }
+ // Validate we consumed all headers.
+ //
+ // The next one over should be a raw payload and then iterator should
+ // terminate.
+ wantDone := false
+ for {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done != wantDone {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone)
+ return
+ }
+ if done {
+ break
+ }
+ if _, ok := h.(header.IPv6RawPayloadHeader); !ok {
+ t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h)
+ continue
+ }
+ wantDone = true
+ }
+ }
+}
+
+var _ header.Network = (*ipv6HeaderWithExtHdr)(nil)
+
+// ipv6HeaderWithExtHdr provides a header.Network implementation that takes
+// extension headers into consideration, which is not the case with vanilla
+// header.IPv6.
+type ipv6HeaderWithExtHdr struct {
+ header.IPv6
+ transport tcpip.TransportProtocolNumber
+ payload []byte
+}
+
+// TransportProtocol implements header.Network.
+func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber {
+ return h.transport
+}
+
+// Payload implements header.Network.
+func (h *ipv6HeaderWithExtHdr) Payload() []byte {
+ return h.payload
+}
+
+// IPv6ExtHdrOptionChecker is a function to check an extension header option.
+type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption)
+
+// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop
+// extension header and validates the containing options with checkers.
+//
+// checkers must exhaustively contain all the expected options.
+func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker {
+ return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) {
+ t.Helper()
+
+ hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr)
+ if !ok {
+ t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader)
+ return
+ }
+ optionsIterator := hbh.Iter()
+ for _, f := range checkers {
+ opt, done, err := optionsIterator.Next()
+ if err != nil {
+ t.Errorf("optionsIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
+ }
+ f(t, opt)
+ }
+ // Validate all options were consumed.
+ for {
+ opt, done, err := optionsIterator.Next()
+ if err != nil {
+ t.Errorf("optionsIterator.Next(): %s", err)
+ return
+ }
+ if !done {
+ t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
+ }
+ if done {
+ break
+ }
+ }
+ }
+}
+
+// IPv6RouterAlert validates that an extension header option is the RouterAlert
+// option and matches on its value.
+func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
+ return func(t *testing.T, opt header.IPv6ExtHdrOption) {
+ routerAlert, ok := opt.(*header.IPv6RouterAlertOption)
+ if !ok {
+ t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt)
+ return
+ }
+ if routerAlert.Value != want {
+ t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want)
+ }
+ }
+}
+
+// IgnoreCmpPath returns a cmp.Option that ignores listed field paths.
+func IgnoreCmpPath(paths ...string) cmp.Option {
+ ignores := map[string]struct{}{}
+ for _, path := range paths {
+ ignores[path] = struct{}{}
+ }
+ return cmp.FilterPath(func(path cmp.Path) bool {
+ _, ok := ignores[path.String()]
+ return ok
+ }, cmp.Ignore())
+}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index d87797617..0bdc12d53 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -11,11 +11,13 @@ go_library(
"gue.go",
"icmpv4.go",
"icmpv6.go",
+ "igmp.go",
"interfaces.go",
"ipv4.go",
"ipv6.go",
"ipv6_extension_headers.go",
"ipv6_fragment.go",
+ "mld.go",
"ndp_neighbor_advert.go",
"ndp_neighbor_solicit.go",
"ndp_options.go",
@@ -39,6 +41,8 @@ go_test(
size = "small",
srcs = [
"checksum_test.go",
+ "igmp_test.go",
+ "ipv4_test.go",
"ipv6_test.go",
"ipversion_test.go",
"tcp_test.go",
@@ -58,6 +62,7 @@ go_test(
srcs = [
"eth_test.go",
"ipv6_extension_headers_test.go",
+ "mld_test.go",
"ndp_test.go",
],
library = ":header",
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index 309403482..5ab20ee86 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -19,6 +19,7 @@ package header_test
import (
"fmt"
"math/rand"
+ "sync"
"testing"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -169,3 +170,96 @@ func BenchmarkChecksum(b *testing.B) {
}
}
}
+
+func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) {
+ // icmpChecksum should not do any modifications of the header to
+ // calculate its checksum. Let's call it from a few go-routines and the
+ // race detector will trigger a warning if there are any concurrent
+ // read/write accesses.
+
+ const concurrency = 5
+ start := make(chan int)
+ ready := make(chan bool, concurrency)
+ var wg sync.WaitGroup
+ wg.Add(concurrency)
+ defer wg.Wait()
+
+ for i := 0; i < concurrency; i++ {
+ go func() {
+ defer wg.Done()
+
+ ready <- true
+ <-start
+
+ if got := headerChecksum(); want != got {
+ t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
+ }
+ if got := icmpChecksum(); want != got {
+ t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
+ }
+ }()
+ }
+ for i := 0; i < concurrency; i++ {
+ <-ready
+ }
+ close(start)
+}
+
+func TestICMPv4Checksum(t *testing.T) {
+ rnd := rand.New(rand.NewSource(42))
+
+ h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize))
+ if _, err := rnd.Read(h); err != nil {
+ t.Fatalf("rnd.Read failed: %v", err)
+ }
+ h.SetChecksum(0)
+
+ buf := make([]byte, 13)
+ if _, err := rnd.Read(buf); err != nil {
+ t.Fatalf("rnd.Read failed: %v", err)
+ }
+ vv := buffer.NewVectorisedView(len(buf), []buffer.View{
+ buffer.NewViewFromBytes(buf[:5]),
+ buffer.NewViewFromBytes(buf[5:]),
+ })
+
+ want := header.Checksum(vv.ToView(), 0)
+ want = ^header.Checksum(h, want)
+ h.SetChecksum(want)
+
+ testICMPChecksum(t, h.Checksum, func() uint16 {
+ return header.ICMPv4Checksum(h, vv)
+ }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
+}
+
+func TestICMPv6Checksum(t *testing.T) {
+ rnd := rand.New(rand.NewSource(42))
+
+ h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize))
+ if _, err := rnd.Read(h); err != nil {
+ t.Fatalf("rnd.Read failed: %v", err)
+ }
+ h.SetChecksum(0)
+
+ buf := make([]byte, 13)
+ if _, err := rnd.Read(buf); err != nil {
+ t.Fatalf("rnd.Read failed: %v", err)
+ }
+ vv := buffer.NewVectorisedView(len(buf), []buffer.View{
+ buffer.NewViewFromBytes(buf[:7]),
+ buffer.NewViewFromBytes(buf[7:10]),
+ buffer.NewViewFromBytes(buf[10:]),
+ })
+
+ dst := header.IPv6Loopback
+ src := header.IPv6Loopback
+
+ want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size()))
+ want = header.Checksum(vv.ToView(), want)
+ want = ^header.Checksum(h, want)
+ h.SetChecksum(want)
+
+ testICMPChecksum(t, h.Checksum, func() uint16 {
+ return header.ICMPv6Checksum(h, src, dst, vv)
+ }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
+}
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 2f13dea6a..5f9b8e9e2 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -16,6 +16,7 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -199,17 +200,24 @@ func (b ICMPv4) SetSequence(sequence uint16) {
// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header,
// and payload.
func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 {
- // Calculate the IPv6 pseudo-header upper-layer checksum.
- xsum := uint16(0)
- for _, v := range vv.Views() {
- xsum = Checksum(v, xsum)
- }
+ xsum := ChecksumVV(vv, 0)
+
+ // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
+ xsum = Checksum(h[:2], xsum)
+ xsum = Checksum(h[4:], xsum)
- // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
- h2, h3 := h[2], h[3]
- h[2], h[3] = 0, 0
- xsum = ^Checksum(h, xsum)
- h[2], h[3] = h2, h3
+ return ^xsum
+}
- return xsum
+// ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when
+// a packet having a `net` header causing an ICMP error.
+func ICMPOriginFromNetProto(net tcpip.NetworkProtocolNumber) tcpip.SockErrOrigin {
+ switch net {
+ case IPv4ProtocolNumber:
+ return tcpip.SockExtErrorOriginICMP
+ case IPv6ProtocolNumber:
+ return tcpip.SockExtErrorOriginICMP6
+ default:
+ panic(fmt.Sprintf("unsupported net proto to extract ICMP error origin: %d", net))
+ }
}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 4303fc5d5..eca9750ab 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -115,6 +115,12 @@ const (
ICMPv6NeighborSolicit ICMPv6Type = 135
ICMPv6NeighborAdvert ICMPv6Type = 136
ICMPv6RedirectMsg ICMPv6Type = 137
+
+ // Multicast Listener Discovery (MLD) messages, see RFC 2710.
+
+ ICMPv6MulticastListenerQuery ICMPv6Type = 130
+ ICMPv6MulticastListenerReport ICMPv6Type = 131
+ ICMPv6MulticastListenerDone ICMPv6Type = 132
)
// IsErrorType returns true if the receiver is an ICMP error type.
@@ -245,10 +251,9 @@ func (b ICMPv6) SetSequence(sequence uint16) {
binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
}
-// NDPPayload returns the NDP payload buffer. That is, it returns the ICMPv6
-// packet's message body as defined by RFC 4443 section 2.1; the portion of the
-// ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
-func (b ICMPv6) NDPPayload() []byte {
+// MessageBody returns the message body as defined by RFC 4443 section 2.1; the
+// portion of the ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
+func (b ICMPv6) MessageBody() []byte {
return b[ICMPv6HeaderSize:]
}
@@ -260,22 +265,13 @@ func (b ICMPv6) Payload() []byte {
// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header,
// IPv6 src/dst addresses and the payload.
func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
- // Calculate the IPv6 pseudo-header upper-layer checksum.
- xsum := Checksum([]byte(src), 0)
- xsum = Checksum([]byte(dst), xsum)
- var upperLayerLength [4]byte
- binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
- xsum = Checksum(upperLayerLength[:], xsum)
- xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum)
- for _, v := range vv.Views() {
- xsum = Checksum(v, xsum)
- }
-
- // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
- h2, h3 := h[2], h[3]
- h[2], h[3] = 0, 0
- xsum = ^Checksum(h, xsum)
- h[2], h[3] = h2, h3
-
- return xsum
+ xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size()))
+
+ xsum = ChecksumVV(vv, xsum)
+
+ // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
+ xsum = Checksum(h[:2], xsum)
+ xsum = Checksum(h[4:], xsum)
+
+ return ^xsum
}
diff --git a/pkg/tcpip/header/igmp.go b/pkg/tcpip/header/igmp.go
new file mode 100644
index 000000000..5c5be1b9d
--- /dev/null
+++ b/pkg/tcpip/header/igmp.go
@@ -0,0 +1,181 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// IGMP represents an IGMP header stored in a byte array.
+type IGMP []byte
+
+// IGMP implements `Transport`.
+var _ Transport = (*IGMP)(nil)
+
+const (
+ // IGMPMinimumSize is the minimum size of a valid IGMP packet in bytes,
+ // as per RFC 2236, Section 2, Page 2.
+ IGMPMinimumSize = 8
+
+ // IGMPQueryMinimumSize is the minimum size of a valid Membership Query
+ // Message in bytes, as per RFC 2236, Section 2, Page 2.
+ IGMPQueryMinimumSize = 8
+
+ // IGMPReportMinimumSize is the minimum size of a valid Report Message in
+ // bytes, as per RFC 2236, Section 2, Page 2.
+ IGMPReportMinimumSize = 8
+
+ // IGMPLeaveMessageMinimumSize is the minimum size of a valid Leave Message
+ // in bytes, as per RFC 2236, Section 2, Page 2.
+ IGMPLeaveMessageMinimumSize = 8
+
+ // IGMPTTL is the TTL for all IGMP messages, as per RFC 2236, Section 3, Page
+ // 3.
+ IGMPTTL = 1
+
+ // igmpTypeOffset defines the offset of the type field in an IGMP message.
+ igmpTypeOffset = 0
+
+ // igmpMaxRespTimeOffset defines the offset of the MaxRespTime field in an
+ // IGMP message.
+ igmpMaxRespTimeOffset = 1
+
+ // igmpChecksumOffset defines the offset of the checksum field in an IGMP
+ // message.
+ igmpChecksumOffset = 2
+
+ // igmpGroupAddressOffset defines the offset of the Group Address field in an
+ // IGMP message.
+ igmpGroupAddressOffset = 4
+
+ // IGMPProtocolNumber is IGMP's transport protocol number.
+ IGMPProtocolNumber tcpip.TransportProtocolNumber = 2
+)
+
+// IGMPType is the IGMP type field as per RFC 2236.
+type IGMPType byte
+
+// Values for the IGMP Type described in RFC 2236 Section 2.1, Page 2.
+// Descriptions below come from there.
+const (
+ // IGMPMembershipQuery indicates that the message type is Membership Query.
+ // "There are two sub-types of Membership Query messages:
+ // - General Query, used to learn which groups have members on an
+ // attached network.
+ // - Group-Specific Query, used to learn if a particular group
+ // has any members on an attached network.
+ // These two messages are differentiated by the Group Address, as
+ // described in section 1.4 ."
+ IGMPMembershipQuery IGMPType = 0x11
+ // IGMPv1MembershipReport indicates that the message is a Membership Report
+ // generated by a host using the IGMPv1 protocol: "an additional type of
+ // message, for backwards-compatibility with IGMPv1"
+ IGMPv1MembershipReport IGMPType = 0x12
+ // IGMPv2MembershipReport indicates that the Message type is a Membership
+ // Report generated by a host using the IGMPv2 protocol.
+ IGMPv2MembershipReport IGMPType = 0x16
+ // IGMPLeaveGroup indicates that the message type is a Leave Group
+ // notification message.
+ IGMPLeaveGroup IGMPType = 0x17
+)
+
+// Type is the IGMP type field.
+func (b IGMP) Type() IGMPType { return IGMPType(b[igmpTypeOffset]) }
+
+// SetType sets the IGMP type field.
+func (b IGMP) SetType(t IGMPType) { b[igmpTypeOffset] = byte(t) }
+
+// MaxRespTime gets the MaxRespTimeField. This is meaningful only in Membership
+// Query messages, in other cases it is set to 0 by the sender and ignored by
+// the receiver.
+func (b IGMP) MaxRespTime() time.Duration {
+ // As per RFC 2236 section 2.2,
+ //
+ // The Max Response Time field is meaningful only in Membership Query
+ // messages, and specifies the maximum allowed time before sending a
+ // responding report in units of 1/10 second. In all other messages, it
+ // is set to zero by the sender and ignored by receivers.
+ return DecisecondToDuration(b[igmpMaxRespTimeOffset])
+}
+
+// SetMaxRespTime sets the MaxRespTimeField.
+func (b IGMP) SetMaxRespTime(m byte) { b[igmpMaxRespTimeOffset] = m }
+
+// Checksum is the IGMP checksum field.
+func (b IGMP) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[igmpChecksumOffset:])
+}
+
+// SetChecksum sets the IGMP checksum field.
+func (b IGMP) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[igmpChecksumOffset:], checksum)
+}
+
+// GroupAddress gets the Group Address field.
+func (b IGMP) GroupAddress() tcpip.Address {
+ return tcpip.Address(b[igmpGroupAddressOffset:][:IPv4AddressSize])
+}
+
+// SetGroupAddress sets the Group Address field.
+func (b IGMP) SetGroupAddress(address tcpip.Address) {
+ if n := copy(b[igmpGroupAddressOffset:], address); n != IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, IPv4AddressSize))
+ }
+}
+
+// SourcePort implements Transport.SourcePort.
+func (IGMP) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (IGMP) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (IGMP) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (IGMP) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (IGMP) Payload() []byte {
+ return nil
+}
+
+// IGMPCalculateChecksum calculates the IGMP checksum over the provided IGMP
+// header.
+func IGMPCalculateChecksum(h IGMP) uint16 {
+ // The header contains a checksum itself, set it aside to avoid checksumming
+ // the checksum and replace it afterwards.
+ existingXsum := h.Checksum()
+ h.SetChecksum(0)
+ xsum := ^Checksum(h, 0)
+ h.SetChecksum(existingXsum)
+ return xsum
+}
+
+// DecisecondToDuration converts a value representing deci-seconds to a
+// time.Duration.
+func DecisecondToDuration(ds uint8) time.Duration {
+ return time.Duration(ds) * time.Second / 10
+}
diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go
new file mode 100644
index 000000000..b6126d29a
--- /dev/null
+++ b/pkg/tcpip/header/igmp_test.go
@@ -0,0 +1,110 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// TestIGMPHeader tests the functions within header.igmp
+func TestIGMPHeader(t *testing.T) {
+ const maxRespTimeTenthSec = 0xF0
+ b := []byte{
+ 0x11, // IGMP Type, Membership Query
+ maxRespTimeTenthSec, // Maximum Response Time
+ 0xC0, 0xC0, // Checksum
+ 0x01, 0x02, 0x03, 0x04, // Group Address
+ }
+
+ igmpHeader := header.IGMP(b)
+
+ if got, want := igmpHeader.Type(), header.IGMPMembershipQuery; got != want {
+ t.Errorf("got igmpHeader.Type() = %x, want = %x", got, want)
+ }
+
+ if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(maxRespTimeTenthSec); got != want {
+ t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want)
+ }
+
+ if got, want := igmpHeader.Checksum(), uint16(0xC0C0); got != want {
+ t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want)
+ }
+
+ if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want {
+ t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want)
+ }
+
+ igmpType := header.IGMPv2MembershipReport
+ igmpHeader.SetType(igmpType)
+ if got := igmpHeader.Type(); got != igmpType {
+ t.Errorf("got igmpHeader.Type() = %x, want = %x", got, igmpType)
+ }
+ if got := header.IGMPType(b[0]); got != igmpType {
+ t.Errorf("got IGMPtype in backing buffer = %x, want %x", got, igmpType)
+ }
+
+ respTime := byte(0x02)
+ igmpHeader.SetMaxRespTime(respTime)
+ if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(respTime); got != want {
+ t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want)
+ }
+
+ checksum := uint16(0x0102)
+ igmpHeader.SetChecksum(checksum)
+ if got := igmpHeader.Checksum(); got != checksum {
+ t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum)
+ }
+
+ groupAddress := tcpip.Address("\x04\x03\x02\x01")
+ igmpHeader.SetGroupAddress(groupAddress)
+ if got := igmpHeader.GroupAddress(); got != groupAddress {
+ t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress)
+ }
+}
+
+// TestIGMPChecksum ensures that the checksum calculator produces the expected
+// checksum.
+func TestIGMPChecksum(t *testing.T) {
+ b := []byte{
+ 0x11, // IGMP Type, Membership Query
+ 0xF0, // Maximum Response Time
+ 0xC0, 0xC0, // Checksum
+ 0x01, 0x02, 0x03, 0x04, // Group Address
+ }
+
+ igmpHeader := header.IGMP(b)
+
+ // Calculate the initial checksum after setting the checksum temporarily to 0
+ // to avoid checksumming the checksum.
+ initialChecksum := igmpHeader.Checksum()
+ igmpHeader.SetChecksum(0)
+ checksum := ^header.Checksum(b, 0)
+ igmpHeader.SetChecksum(initialChecksum)
+
+ if got := header.IGMPCalculateChecksum(igmpHeader); got != checksum {
+ t.Errorf("got IGMPCalculateChecksum = %x, want %x", got, checksum)
+ }
+}
+
+func TestDecisecondToDuration(t *testing.T) {
+ const valueInDeciseconds = 5
+ if got, want := header.DecisecondToDuration(valueInDeciseconds), valueInDeciseconds*time.Second/10; got != want {
+ t.Fatalf("got header.DecisecondToDuration(%d) = %s, want = %s", valueInDeciseconds, got, want)
+ }
+}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 7e32b31b4..e6103f4bc 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -89,8 +89,18 @@ type IPv4Fields struct {
// DstAddr is the "destination ip address" of an IPv4 packet.
DstAddr tcpip.Address
- // Options is between 0 and 40 bytes or nil if empty.
- Options IPv4Options
+ // Options must be 40 bytes or less as they must fit along with the
+ // rest of the IPv4 header into the maximum size describable in the
+ // IHL field. RFC 791 section 3.1 says:
+ // IHL: 4 bits
+ //
+ // Internet Header Length is the length of the internet header in 32
+ // bit words, and thus points to the beginning of the data. Note that
+ // the minimum value for a correct header is 5.
+ //
+ // That leaves ten 32 bit (4 byte) fields for options. An attempt to encode
+ // more will fail.
+ Options IPv4OptionsSerializer
}
// IPv4 is an IPv4 header.
@@ -147,6 +157,9 @@ const (
// IPv4Any is the non-routable IPv4 "any" meta address.
IPv4Any tcpip.Address = "\x00\x00\x00\x00"
+ // IPv4AllRoutersGroup is a multicast address for all routers.
+ IPv4AllRoutersGroup tcpip.Address = "\xe0\x00\x00\x02"
+
// IPv4MinimumProcessableDatagramSize is the minimum size of an IP
// packet that every IPv4 capable host must be able to
// process/reassemble.
@@ -272,25 +285,21 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
-// IPv4Options is a buffer that holds all the raw IP options.
-type IPv4Options []byte
-
-// AllocationSize implements stack.NetOptions.
-// It reports the size to allocate for the Options. RFC 791 page 23 (end of
-// section 3.1) says of the padding at the end of the options:
+// padIPv4OptionsLength returns the total length for IPv4 options of length l
+// after applying padding according to RFC 791:
// The internet header padding is used to ensure that the internet
// header ends on a 32 bit boundary.
-func (o IPv4Options) AllocationSize() int {
- return (len(o) + IPv4IHLStride - 1) & ^(IPv4IHLStride - 1)
+func padIPv4OptionsLength(length uint8) uint8 {
+ return (length + IPv4IHLStride - 1) & ^uint8(IPv4IHLStride-1)
}
-// Options returns a buffer holding the options or nil.
+// IPv4Options is a buffer that holds all the raw IP options.
+type IPv4Options []byte
+
+// Options returns a buffer holding the options.
func (b IPv4) Options() IPv4Options {
hdrLen := b.HeaderLength()
- if hdrLen > IPv4MinimumSize {
- return IPv4Options(b[options:hdrLen:hdrLen])
- }
- return nil
+ return IPv4Options(b[options:hdrLen:hdrLen])
}
// TransportProtocol implements Network.TransportProtocol.
@@ -365,20 +374,16 @@ func (b IPv4) CalculateChecksum() uint16 {
func (b IPv4) Encode(i *IPv4Fields) {
// The size of the options defines the size of the whole header and thus the
// IHL field. Options are rare and this is a heavily used function so it is
- // worth a bit of optimisation here to keep the copy out of the fast path.
- hdrLen := IPv4MinimumSize
+ // worth a bit of optimisation here to keep the serializer out of the fast
+ // path.
+ hdrLen := uint8(IPv4MinimumSize)
if len(i.Options) != 0 {
- // AllocationSize is always >= len(i.Options).
- aLen := i.Options.AllocationSize()
- hdrLen += aLen
- if hdrLen > len(b) {
- panic(fmt.Sprintf("encode received %d bytes, wanted >= %d", len(b), hdrLen))
- }
- if aLen != copy(b[options:], i.Options) {
- _ = copy(b[options+len(i.Options):options+aLen], []byte{0, 0, 0, 0})
- }
+ hdrLen += i.Options.Serialize(b[options:])
}
- b.SetHeaderLength(uint8(hdrLen))
+ if hdrLen > IPv4MaximumHeaderSize {
+ panic(fmt.Sprintf("%d is larger than maximum IPv4 header size of %d", hdrLen, IPv4MaximumHeaderSize))
+ }
+ b.SetHeaderLength(hdrLen)
b[tos] = i.TOS
b.SetTotalLength(i.TotalLength)
binary.BigEndian.PutUint16(b[id:], i.ID)
@@ -458,6 +463,10 @@ const (
// options and may appear multiple times.
IPv4OptionNOPType IPv4OptionType = 1
+ // IPv4OptionRouterAlertType is the option type for the Router Alert option,
+ // defined in RFC 2113 Section 2.1.
+ IPv4OptionRouterAlertType IPv4OptionType = 20 | 0x80
+
// IPv4OptionRecordRouteType is used by each router on the path of the packet
// to record its path. It is carried over to an Echo Reply.
IPv4OptionRecordRouteType IPv4OptionType = 7
@@ -858,3 +867,162 @@ func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) }
// Contents implements IPv4Option.
func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) }
+
+// Router Alert option specific related constants.
+//
+// from RFC 2113 section 2.1:
+//
+// +--------+--------+--------+--------+
+// |10010100|00000100| 2 octet value |
+// +--------+--------+--------+--------+
+//
+// Type:
+// Copied flag: 1 (all fragments must carry the option)
+// Option class: 0 (control)
+// Option number: 20 (decimal)
+//
+// Length: 4
+//
+// Value: A two octet code with the following values:
+// 0 - Router shall examine packet
+// 1-65535 - Reserved
+const (
+ // IPv4OptionRouterAlertLength is the length of a Router Alert option.
+ IPv4OptionRouterAlertLength = 4
+
+ // IPv4OptionRouterAlertValue is the only permissible value of the 16 bit
+ // payload of the router alert option.
+ IPv4OptionRouterAlertValue = 0
+
+ // iPv4OptionRouterAlertValueOffset is the offset for the value of a
+ // RouterAlert option.
+ iPv4OptionRouterAlertValueOffset = 2
+)
+
+// IPv4SerializableOption is an interface to represent serializable IPv4 option
+// types.
+type IPv4SerializableOption interface {
+ // optionType returns the type identifier of the option.
+ optionType() IPv4OptionType
+}
+
+// IPv4SerializableOptionPayload is an interface providing serialization of the
+// payload of an IPv4 option.
+type IPv4SerializableOptionPayload interface {
+ // length returns the size of the payload.
+ length() uint8
+
+ // serializeInto serializes the payload into the provided byte buffer.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // Length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MUST panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto will return the number of bytes that was used to
+ // serialize the receiver. Implementers must only use the number of
+ // bytes required to serialize the receiver. Callers MAY provide a
+ // larger buffer than required to serialize into.
+ serializeInto(buffer []byte) uint8
+}
+
+// IPv4OptionsSerializer is a serializer for IPv4 options.
+type IPv4OptionsSerializer []IPv4SerializableOption
+
+// Length returns the total number of bytes required to serialize the options.
+func (s IPv4OptionsSerializer) Length() uint8 {
+ var total uint8
+ for _, opt := range s {
+ total++
+ if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok {
+ // Add 1 to reported length to account for the length byte.
+ total += 1 + withPayload.length()
+ }
+ }
+ return padIPv4OptionsLength(total)
+}
+
+// Serialize serializes the provided list of IPV4 options into b.
+//
+// Note, b must be of sufficient size to hold all the options in s. See
+// IPv4OptionsSerializer.Length for details on the getting the total size
+// of a serialized IPv4OptionsSerializer.
+//
+// Serialize panics if b is not of sufficient size to hold all the options in s.
+func (s IPv4OptionsSerializer) Serialize(b []byte) uint8 {
+ var total uint8
+ for _, opt := range s {
+ ty := opt.optionType()
+ if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok {
+ // Serialize first to reduce bounds checks.
+ l := 2 + withPayload.serializeInto(b[2:])
+ b[0] = byte(ty)
+ b[1] = l
+ b = b[l:]
+ total += l
+ continue
+ }
+ // Options without payload consist only of the type field.
+ //
+ // NB: Repeating code from the branch above is intentional to minimize
+ // bounds checks.
+ b[0] = byte(ty)
+ b = b[1:]
+ total++
+ }
+
+ // According to RFC 791:
+ //
+ // The internet header padding is used to ensure that the internet
+ // header ends on a 32 bit boundary. The padding is zero.
+ padded := padIPv4OptionsLength(total)
+ b = b[:padded-total]
+ for i := range b {
+ b[i] = 0
+ }
+ return padded
+}
+
+var _ IPv4SerializableOptionPayload = (*IPv4SerializableRouterAlertOption)(nil)
+var _ IPv4SerializableOption = (*IPv4SerializableRouterAlertOption)(nil)
+
+// IPv4SerializableRouterAlertOption provides serialization of the Router Alert
+// IPv4 option according to RFC 2113.
+type IPv4SerializableRouterAlertOption struct{}
+
+// Type implements IPv4SerializableOption.
+func (*IPv4SerializableRouterAlertOption) optionType() IPv4OptionType {
+ return IPv4OptionRouterAlertType
+}
+
+// Length implements IPv4SerializableOption.
+func (*IPv4SerializableRouterAlertOption) length() uint8 {
+ return IPv4OptionRouterAlertLength - iPv4OptionRouterAlertValueOffset
+}
+
+// SerializeInto implements IPv4SerializableOption.
+func (o *IPv4SerializableRouterAlertOption) serializeInto(buffer []byte) uint8 {
+ binary.BigEndian.PutUint16(buffer, IPv4OptionRouterAlertValue)
+ return o.length()
+}
+
+var _ IPv4SerializableOption = (*IPv4SerializableNOPOption)(nil)
+
+// IPv4SerializableNOPOption provides serialization for the IPv4 no-op option.
+type IPv4SerializableNOPOption struct{}
+
+// Type implements IPv4SerializableOption.
+func (*IPv4SerializableNOPOption) optionType() IPv4OptionType {
+ return IPv4OptionNOPType
+}
+
+var _ IPv4SerializableOption = (*IPv4SerializableListEndOption)(nil)
+
+// IPv4SerializableListEndOption provides serialization for the IPv4 List End
+// option.
+type IPv4SerializableListEndOption struct{}
+
+// Type implements IPv4SerializableOption.
+func (*IPv4SerializableListEndOption) optionType() IPv4OptionType {
+ return IPv4OptionListEndType
+}
diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go
new file mode 100644
index 000000000..6475cd694
--- /dev/null
+++ b/pkg/tcpip/header/ipv4_test.go
@@ -0,0 +1,179 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+func TestIPv4OptionsSerializer(t *testing.T) {
+ optCases := []struct {
+ name string
+ option []header.IPv4SerializableOption
+ expect []byte
+ }{
+ {
+ name: "NOP",
+ option: []header.IPv4SerializableOption{
+ &header.IPv4SerializableNOPOption{},
+ },
+ expect: []byte{1, 0, 0, 0},
+ },
+ {
+ name: "ListEnd",
+ option: []header.IPv4SerializableOption{
+ &header.IPv4SerializableListEndOption{},
+ },
+ expect: []byte{0, 0, 0, 0},
+ },
+ {
+ name: "RouterAlert",
+ option: []header.IPv4SerializableOption{
+ &header.IPv4SerializableRouterAlertOption{},
+ },
+ expect: []byte{148, 4, 0, 0},
+ }, {
+ name: "NOP and RouterAlert",
+ option: []header.IPv4SerializableOption{
+ &header.IPv4SerializableNOPOption{},
+ &header.IPv4SerializableRouterAlertOption{},
+ },
+ expect: []byte{1, 148, 4, 0, 0, 0, 0, 0},
+ },
+ }
+
+ for _, opt := range optCases {
+ t.Run(opt.name, func(t *testing.T) {
+ s := header.IPv4OptionsSerializer(opt.option)
+ l := s.Length()
+ if got := len(opt.expect); got != int(l) {
+ t.Fatalf("s.Length() = %d, want = %d", got, l)
+ }
+ b := make([]byte, l)
+ for i := range b {
+ // Fill the buffer with full bytes to ensure padding is being set
+ // correctly.
+ b[i] = 0xFF
+ }
+ if serializedLength := s.Serialize(b); serializedLength != l {
+ t.Fatalf("s.Serialize(_) = %d, want %d", serializedLength, l)
+ }
+ if diff := cmp.Diff(opt.expect, b); diff != "" {
+ t.Errorf("mismatched serialized option (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested
+// fields when options are supplied.
+func TestIPv4EncodeOptions(t *testing.T) {
+ tests := []struct {
+ name string
+ numberOfNops int
+ encodedOptions header.IPv4Options // reply should look like this
+ wantIHL int
+ }{
+ {
+ name: "valid no options",
+ wantIHL: header.IPv4MinimumSize,
+ },
+ {
+ name: "one byte options",
+ numberOfNops: 1,
+ encodedOptions: header.IPv4Options{1, 0, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "two byte options",
+ numberOfNops: 2,
+ encodedOptions: header.IPv4Options{1, 1, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "three byte options",
+ numberOfNops: 3,
+ encodedOptions: header.IPv4Options{1, 1, 1, 0},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "four byte options",
+ numberOfNops: 4,
+ encodedOptions: header.IPv4Options{1, 1, 1, 1},
+ wantIHL: header.IPv4MinimumSize + 4,
+ },
+ {
+ name: "five byte options",
+ numberOfNops: 5,
+ encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0},
+ wantIHL: header.IPv4MinimumSize + 8,
+ },
+ {
+ name: "thirty nine byte options",
+ numberOfNops: 39,
+ encodedOptions: header.IPv4Options{
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 0,
+ },
+ wantIHL: header.IPv4MinimumSize + 40,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ serializeOpts := header.IPv4OptionsSerializer(make([]header.IPv4SerializableOption, test.numberOfNops))
+ for i := range serializeOpts {
+ serializeOpts[i] = &header.IPv4SerializableNOPOption{}
+ }
+ paddedOptionLength := serializeOpts.Length()
+ ipHeaderLength := int(header.IPv4MinimumSize + paddedOptionLength)
+ if ipHeaderLength > header.IPv4MaximumHeaderSize {
+ t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ }
+ totalLen := uint16(ipHeaderLength)
+ hdr := buffer.NewPrependable(int(totalLen))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+ // To check the padding works, poison the last byte of the options space.
+ if paddedOptionLength != serializeOpts.Length() {
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ ip.Options()[paddedOptionLength-1] = 0xff
+ ip.SetHeaderLength(0)
+ }
+ ip.Encode(&header.IPv4Fields{
+ Options: serializeOpts,
+ })
+ options := ip.Options()
+ wantOptions := test.encodedOptions
+ if got, want := int(ip.HeaderLength()), test.wantIHL; got != want {
+ t.Errorf("got IHL of %d, want %d", got, want)
+ }
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(wantOptions) == 0 && len(options) == 0 {
+ return
+ }
+
+ if diff := cmp.Diff(wantOptions, options); diff != "" {
+ t.Errorf("options mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 4e7e5f76a..5580d6a78 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -18,7 +18,6 @@ import (
"crypto/sha256"
"encoding/binary"
"fmt"
- "strings"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -48,13 +47,15 @@ type IPv6Fields struct {
// FlowLabel is the "flow label" field of an IPv6 packet.
FlowLabel uint32
- // PayloadLength is the "payload length" field of an IPv6 packet.
+ // PayloadLength is the "payload length" field of an IPv6 packet, including
+ // the length of all extension headers.
PayloadLength uint16
- // NextHeader is the "next header" field of an IPv6 packet.
- NextHeader uint8
+ // TransportProtocol is the transport layer protocol number. Serialized in the
+ // last "next header" field of the IPv6 header + extension headers.
+ TransportProtocol tcpip.TransportProtocolNumber
- // HopLimit is the "hop limit" field of an IPv6 packet.
+ // HopLimit is the "Hop Limit" field of an IPv6 packet.
HopLimit uint8
// SrcAddr is the "source ip address" of an IPv6 packet.
@@ -62,6 +63,9 @@ type IPv6Fields struct {
// DstAddr is the "destination ip address" of an IPv6 packet.
DstAddr tcpip.Address
+
+ // ExtensionHeaders are the extension headers following the IPv6 header.
+ ExtensionHeaders IPv6ExtHdrSerializer
}
// IPv6 represents an ipv6 header stored in a byte array.
@@ -148,13 +152,17 @@ const (
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to
// be contained within this subnet.
-var IPv6EmptySubnet = func() tcpip.Subnet {
- subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any))
- if err != nil {
- panic(err)
- }
- return subnet
-}()
+var IPv6EmptySubnet = tcpip.AddressWithPrefix{
+ Address: IPv6Any,
+ PrefixLen: 0,
+}.Subnet()
+
+// IPv4MappedIPv6Subnet is the prefix for an IPv4 mapped IPv6 address as defined
+// by RFC 4291 section 2.5.5.
+var IPv4MappedIPv6Subnet = tcpip.AddressWithPrefix{
+ Address: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00",
+ PrefixLen: 96,
+}.Subnet()
// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined
// by RFC 4291 section 2.5.6.
@@ -171,7 +179,7 @@ func (b IPv6) PayloadLength() uint16 {
return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
}
-// HopLimit returns the value of the "hop limit" field of the ipv6 header.
+// HopLimit returns the value of the "Hop Limit" field of the ipv6 header.
func (b IPv6) HopLimit() uint8 {
return b[hopLimit]
}
@@ -236,6 +244,11 @@ func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
copy(b[v6DstAddr:][:IPv6AddressSize], addr)
}
+// SetHopLimit sets the value of the "Hop Limit" field.
+func (b IPv6) SetHopLimit(v uint8) {
+ b[hopLimit] = v
+}
+
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
func (b IPv6) SetNextHeader(v uint8) {
b[IPv6NextHeaderOffset] = v
@@ -248,12 +261,14 @@ func (IPv6) SetChecksum(uint16) {
// Encode encodes all the fields of the ipv6 header.
func (b IPv6) Encode(i *IPv6Fields) {
+ extHdr := b[IPv6MinimumSize:]
b.SetTOS(i.TrafficClass, i.FlowLabel)
b.SetPayloadLength(i.PayloadLength)
- b[IPv6NextHeaderOffset] = i.NextHeader
b[hopLimit] = i.HopLimit
b.SetSourceAddress(i.SrcAddr)
b.SetDestinationAddress(i.DstAddr)
+ nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr)
+ b[IPv6NextHeaderOffset] = nextHeader
}
// IsValid performs basic validation on the packet.
@@ -281,7 +296,7 @@ func IsV4MappedAddress(addr tcpip.Address) bool {
return false
}
- return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff")
+ return IPv4MappedIPv6Subnet.Contains(addr)
}
// IsV6MulticastAddress determines if the provided address is an IPv6
@@ -387,17 +402,6 @@ func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope
}
-// IsV6UniqueLocalAddress determines if the provided address is an IPv6
-// unique-local address (within the prefix FC00::/7).
-func IsV6UniqueLocalAddress(addr tcpip.Address) bool {
- if len(addr) != IPv6AddressSize {
- return false
- }
- // According to RFC 4193 section 3.1, a unique local address has the prefix
- // FC00::/7.
- return (addr[0] & 0xfe) == 0xfc
-}
-
// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
//
@@ -444,9 +448,6 @@ const (
// LinkLocalScope indicates a link-local address.
LinkLocalScope IPv6AddressScope = iota
- // UniqueLocalScope indicates a unique-local address.
- UniqueLocalScope
-
// GlobalScope indicates a global address.
GlobalScope
)
@@ -464,9 +465,6 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
case IsV6LinkLocalAddress(addr):
return LinkLocalScope, nil
- case IsV6UniqueLocalAddress(addr):
- return UniqueLocalScope, nil
-
default:
return GlobalScope, nil
}
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 583c2c5d3..f18981332 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -18,9 +18,12 @@ import (
"bufio"
"bytes"
"encoding/binary"
+ "errors"
"fmt"
"io"
+ "math"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -47,6 +50,11 @@ const (
// IPv6NoNextHeaderIdentifier is the header identifier used to signify the end
// of an IPv6 payload, as per RFC 8200 section 4.7.
IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59
+
+ // IPv6UnknownExtHdrIdentifier is reserved by IANA.
+ // https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml#extension-header
+ // "254 Use for experimentation and testing [RFC3692][RFC4727]"
+ IPv6UnknownExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 254
)
const (
@@ -70,8 +78,8 @@ const (
// Fragment Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetOffset = 0
- // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to
- // discard from the Fragment Offset.
+ // ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment
+ // Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetShift = 3
// ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an
@@ -109,6 +117,37 @@ const (
IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8
)
+// padIPv6OptionsLength returns the total length for IPv6 options of length l
+// considering the 8-octet alignment as stated in RFC 8200 Section 4.2.
+func padIPv6OptionsLength(length int) int {
+ return (length + ipv6ExtHdrLenBytesPerUnit - 1) & ^(ipv6ExtHdrLenBytesPerUnit - 1)
+}
+
+// padIPv6Option fills b with the appropriate padding options depending on its
+// length.
+func padIPv6Option(b []byte) {
+ switch len(b) {
+ case 0: // No padding needed.
+ case 1: // Pad with Pad1.
+ b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier)
+ default: // Pad with PadN.
+ s := b[ipv6ExtHdrOptionPayloadOffset:]
+ for i := range s {
+ s[i] = 0
+ }
+ b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier)
+ b[ipv6ExtHdrOptionLengthOffset] = uint8(len(s))
+ }
+}
+
+// ipv6OptionsAlignmentPadding returns the number of padding bytes needed to
+// serialize an option at headerOffset with alignment requirements
+// [align]n + alignOffset.
+func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int {
+ padLen := headerOffset - alignOffset
+ return ((padLen + align - 1) & ^(align - 1)) - padLen
+}
+
// IPv6PayloadHeader is implemented by the various headers that can be found
// in an IPv6 payload.
//
@@ -201,29 +240,55 @@ type IPv6ExtHdrOption interface {
isIPv6ExtHdrOption()
}
-// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier.
-type IPv6ExtHdrOptionIndentifier uint8
+// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier.
+type IPv6ExtHdrOptionIdentifier uint8
const (
// ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that
// provides 1 byte padding, as outlined in RFC 8200 section 4.2.
- ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0
+ ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0
// ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that
// provides variable length byte padding, as outlined in RFC 8200 section 4.2.
- ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1
+ ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1
+
+ // ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router
+ // Alert Hop by Hop option as defined in RFC 2711 section 2.1.
+ ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5
+
+ // ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header
+ // option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionTypeOffset = 0
+
+ // ipv6ExtHdrOptionLengthOffset is the option length offset in an extension
+ // header option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionLengthOffset = 1
+
+ // ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension
+ // header option as defined in RFC 8200 section 4.2.
+ ipv6ExtHdrOptionPayloadOffset = 2
)
+// ipv6UnknownActionFromIdentifier maps an extension header option's
+// identifier's high bits to the action to take when the identifier is unknown.
+func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction {
+ return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+}
+
+// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option
+// is malformed.
+var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
+
// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
// header option that is unknown by the parsing utilities.
type IPv6UnknownExtHdrOption struct {
- Identifier IPv6ExtHdrOptionIndentifier
+ Identifier IPv6ExtHdrOptionIdentifier
Data []byte
}
// UnknownAction implements IPv6OptionUnknownAction.UnknownAction.
func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
- return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+ return ipv6UnknownActionFromIdentifier(o.Identifier)
}
// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
@@ -246,7 +311,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
// options buffer has been exhausted and we are done iterating.
return nil, true, nil
}
- id := IPv6ExtHdrOptionIndentifier(temp)
+ id := IPv6ExtHdrOptionIdentifier(temp)
// If the option identifier indicates the option is a Pad1 option, then we
// know the option does not have Length and Data fields. End processing of
@@ -289,6 +354,19 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
}
continue
+ case ipv6RouterAlertHopByHopOptionIdentifier:
+ var routerAlertValue [ipv6RouterAlertPayloadLength]byte
+ if n, err := io.ReadFull(&i.reader, routerAlertValue[:]); err != nil {
+ switch err {
+ case io.EOF, io.ErrUnexpectedEOF:
+ return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
+ default:
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err)
+ }
+ } else if n != int(length) {
+ return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
+ }
+ return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil
default:
bytes := make([]byte, length)
if n, err := io.ReadFull(&i.reader, bytes); err != nil {
@@ -452,9 +530,11 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
// Since we consume the iterator, we return the payload as is.
buf = i.payload
- // Mark i as done.
+ // Mark i as done, but keep track of where we were for error reporting.
*i = IPv6PayloadIterator{
nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
+ headerOffset: i.headerOffset,
+ nextOffset: i.nextOffset,
}
} else {
buf = i.payload.Clone(nil)
@@ -602,3 +682,248 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil
}
+
+// IPv6SerializableExtHdr provides serialization for IPv6 extension
+// headers.
+type IPv6SerializableExtHdr interface {
+ // identifier returns the assigned IPv6 header identifier for this extension
+ // header.
+ identifier() IPv6ExtensionHeaderIdentifier
+
+ // length returns the total serialized length in bytes of this extension
+ // header, including the common next header and length fields.
+ length() int
+
+ // serializeInto serializes the receiver into the provided byte
+ // buffer and with the provided nextHeader value.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto returns the number of bytes that was used to serialize the
+ // receiver. Implementers must only use the number of bytes required to
+ // serialize the receiver. Callers MAY provide a larger buffer than required
+ // to serialize into.
+ serializeInto(nextHeader uint8, b []byte) int
+}
+
+var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil)
+
+// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop
+// options extension header.
+type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption
+
+const (
+ // ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field
+ // in a hop by hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrNextHeaderOffset = 0
+
+ // ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop
+ // by hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrLengthOffset = 1
+
+ // ipv6HopByHopExtHdrPayloadOffset is the offset of the options in a hop by
+ // hop extension header as defined in RFC 8200 section 4.3.
+ ipv6HopByHopExtHdrOptionsOffset = 2
+
+ // ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet
+ // words in a hop by hop extension header's length field, as stated in RFC
+ // 8200 section 4.3:
+ // Length of the Hop-by-Hop Options header in 8-octet units,
+ // not including the first 8 octets.
+ ipv6HopByHopExtHdrUnaccountedLenWords = 1
+)
+
+// identifier implements IPv6SerializableExtHdr.
+func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
+ return IPv6HopByHopOptionsExtHdrIdentifier
+}
+
+// length implements IPv6SerializableExtHdr.
+func (h IPv6SerializableHopByHopExtHdr) length() int {
+ var total int
+ for _, opt := range h {
+ align, alignOffset := opt.alignment()
+ total += ipv6OptionsAlignmentPadding(total, align, alignOffset)
+ total += ipv6ExtHdrOptionPayloadOffset + int(opt.length())
+ }
+ // Account for next header and total length fields and add padding.
+ return padIPv6OptionsLength(ipv6HopByHopExtHdrOptionsOffset + total)
+}
+
+// serializeInto implements IPv6SerializableExtHdr.
+func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int {
+ optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:]
+ totalLength := ipv6HopByHopExtHdrOptionsOffset
+ for _, opt := range h {
+ // Calculate alignment requirements and pad buffer if necessary.
+ align, alignOffset := opt.alignment()
+ padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset)
+ if padLen != 0 {
+ padIPv6Option(optBuffer[:padLen])
+ totalLength += padLen
+ optBuffer = optBuffer[padLen:]
+ }
+
+ l := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:])
+ optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier())
+ optBuffer[ipv6ExtHdrOptionLengthOffset] = l
+ l += ipv6ExtHdrOptionPayloadOffset
+ totalLength += int(l)
+ optBuffer = optBuffer[l:]
+ }
+ padded := padIPv6OptionsLength(totalLength)
+ if padded != totalLength {
+ padIPv6Option(optBuffer[:padded-totalLength])
+ totalLength = padded
+ }
+ wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords
+ if wordsLen > math.MaxUint8 {
+ panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen))
+ }
+ b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader
+ b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen)
+ return totalLength
+}
+
+// IPv6SerializableHopByHopOption provides serialization for hop by hop options.
+type IPv6SerializableHopByHopOption interface {
+ // identifier returns the option identifier of this Hop by Hop option.
+ identifier() IPv6ExtHdrOptionIdentifier
+
+ // length returns the *payload* size of the option (not considering the type
+ // and length fields).
+ length() uint8
+
+ // alignment returns the alignment requirements from this option.
+ //
+ // Alignment requirements take the form [align]n + offset as specified in
+ // RFC 8200 section 4.2. The alignment requirement is on the offset between
+ // the option type byte and the start of the hop by hop header.
+ //
+ // align must be a power of 2.
+ alignment() (align int, offset int)
+
+ // serializeInto serializes the receiver into the provided byte
+ // buffer.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto will return the number of bytes that was used to
+ // serialize the receiver. Implementers must only use the number of
+ // bytes required to serialize the receiver. Callers MAY provide a
+ // larger buffer than required to serialize into.
+ serializeInto([]byte) uint8
+}
+
+var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil)
+
+// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in
+// RFC 2711 section 2.1.
+type IPv6RouterAlertOption struct {
+ Value IPv6RouterAlertValue
+}
+
+// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option.
+type IPv6RouterAlertValue uint16
+
+const (
+ // IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener
+ // Discovery message as defined in RFC 2711 section 2.1.
+ IPv6RouterAlertMLD IPv6RouterAlertValue = 0
+ // IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as
+ // defined in RFC 2711 section 2.1.
+ IPv6RouterAlertRSVP IPv6RouterAlertValue = 1
+ // IPv6RouterAlertActiveNetworks indicates a datagram containing an Active
+ // Networks message as defined in RFC 2711 section 2.1.
+ IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2
+
+ // ipv6RouterAlertPayloadLength is the length of the Router Alert payload
+ // as defined in RFC 2711.
+ ipv6RouterAlertPayloadLength = 2
+
+ // ipv6RouterAlertAlignmentRequirement is the alignment requirement for the
+ // Router Alert option defined as 2n+0 in RFC 2711.
+ ipv6RouterAlertAlignmentRequirement = 2
+
+ // ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset
+ // requirement for the Router Alert option defined as 2n+0 in RFC 2711 section
+ // 2.1.
+ ipv6RouterAlertAlignmentOffsetRequirement = 0
+)
+
+// UnknownAction implements IPv6ExtHdrOption.
+func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction {
+ return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier)
+}
+
+// isIPv6ExtHdrOption implements IPv6ExtHdrOption.
+func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {}
+
+// identifier implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier {
+ return ipv6RouterAlertHopByHopOptionIdentifier
+}
+
+// length implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) length() uint8 {
+ return ipv6RouterAlertPayloadLength
+}
+
+// alignment implements IPv6SerializableHopByHopOption.
+func (*IPv6RouterAlertOption) alignment() (int, int) {
+ // From RFC 2711 section 2.1:
+ // Alignment requirement: 2n+0.
+ return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement
+}
+
+// serializeInto implements IPv6SerializableHopByHopOption.
+func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 {
+ binary.BigEndian.PutUint16(b, uint16(o.Value))
+ return ipv6RouterAlertPayloadLength
+}
+
+// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers.
+type IPv6ExtHdrSerializer []IPv6SerializableExtHdr
+
+// Serialize serializes the provided list of IPv6 extension headers into b.
+//
+// Note, b must be of sufficient size to hold all the headers in s. See
+// IPv6ExtHdrSerializer.Length for details on the getting the total size of a
+// serialized IPv6ExtHdrSerializer.
+//
+// Serialize may panic if b is not of sufficient size to hold all the options
+// in s.
+//
+// Serialize takes the transportProtocol value to be used as the last extension
+// header's Next Header value and returns the header identifier of the first
+// serialized extension header and the total serialized length.
+func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) {
+ nextHeader := uint8(transportProtocol)
+ if len(s) == 0 {
+ return nextHeader, 0
+ }
+ var totalLength int
+ for i, h := range s[:len(s)-1] {
+ length := h.serializeInto(uint8(s[i+1].identifier()), b)
+ b = b[length:]
+ totalLength += length
+ }
+ totalLength += s[len(s)-1].serializeInto(nextHeader, b)
+ return uint8(s[0].identifier()), totalLength
+}
+
+// Length returns the total number of bytes required to serialize the extension
+// headers.
+func (s IPv6ExtHdrSerializer) Length() int {
+ var totalLength int
+ for _, h := range s {
+ totalLength += h.length()
+ }
+ return totalLength
+}
diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go
index ab20c5f37..65adc6250 100644
--- a/pkg/tcpip/header/ipv6_extension_headers_test.go
+++ b/pkg/tcpip/header/ipv6_extension_headers_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -59,7 +60,7 @@ func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool
func TestIPv6UnknownExtHdrOption(t *testing.T) {
tests := []struct {
name string
- identifier IPv6ExtHdrOptionIndentifier
+ identifier IPv6ExtHdrOptionIdentifier
expectedUnknownAction IPv6OptionUnknownAction
}{
{
@@ -211,6 +212,31 @@ func TestIPv6OptionsExtHdrIterErr(t *testing.T) {
bytes: []byte{1, 3},
err: io.ErrUnexpectedEOF,
},
+ {
+ name: "Router alert without data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 0},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with partial data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with partial data and Pad1",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1, 0},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with extra data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 3, 1, 2, 3},
+ err: ErrMalformedIPv6ExtHdrOption,
+ },
+ {
+ name: "Router alert with missing data",
+ bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1},
+ err: io.ErrUnexpectedEOF,
+ },
}
check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) {
@@ -990,3 +1016,331 @@ func TestIPv6ExtHdrIter(t *testing.T) {
})
}
}
+
+var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil)
+
+// dummyHbHOptionSerializer provides a generic implementation of
+// IPv6SerializableHopByHopOption for use in tests.
+type dummyHbHOptionSerializer struct {
+ id IPv6ExtHdrOptionIdentifier
+ payload []byte
+ align int
+ alignOffset int
+}
+
+// identifier implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier {
+ return s.id
+}
+
+// length implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) length() uint8 {
+ return uint8(len(s.payload))
+}
+
+// alignment implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) alignment() (int, int) {
+ align := 1
+ if s.align != 0 {
+ align = s.align
+ }
+ return align, s.alignOffset
+}
+
+// serializeInto implements IPv6SerializableHopByHopOption.
+func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 {
+ return uint8(copy(b, s.payload))
+}
+
+func TestIPv6HopByHopSerializer(t *testing.T) {
+ validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) {
+ t.Helper()
+ dummy, ok := serializable.(*dummyHbHOptionSerializer)
+ if !ok {
+ t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable)
+ }
+ unknown, ok := deserialized.(*IPv6UnknownExtHdrOption)
+ if !ok {
+ t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{})
+ }
+ if dummy.id != unknown.Identifier {
+ t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id)
+ }
+ if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" {
+ t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff)
+ }
+ }
+ tests := []struct {
+ name string
+ nextHeader uint8
+ options []IPv6SerializableHopByHopOption
+ expect []byte
+ validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption)
+ }{
+ {
+ name: "single option",
+ nextHeader: 13,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 15,
+ payload: []byte{9, 8, 7, 6},
+ },
+ },
+ expect: []byte{13, 0, 15, 4, 9, 8, 7, 6},
+ validate: validateDummies,
+ },
+ {
+ name: "short option padN zero",
+ nextHeader: 88,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5},
+ },
+ },
+ expect: []byte{88, 0, 22, 2, 4, 5, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "short option pad1",
+ nextHeader: 11,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 33,
+ payload: []byte{1, 2, 3},
+ },
+ },
+ expect: []byte{11, 0, 33, 3, 1, 2, 3, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "long option padN",
+ nextHeader: 55,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 77,
+ payload: []byte{1, 2, 3, 4, 5, 6, 7, 8},
+ },
+ },
+ expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2, 3},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ },
+ },
+ expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options align 2n",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2, 3},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ align: 2,
+ },
+ },
+ expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "two options align 8n+1",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{
+ &dummyHbHOptionSerializer{
+ id: 11,
+ payload: []byte{1, 2},
+ },
+ &dummyHbHOptionSerializer{
+ id: 22,
+ payload: []byte{4, 5, 6},
+ align: 8,
+ alignOffset: 1,
+ },
+ },
+ expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0},
+ validate: validateDummies,
+ },
+ {
+ name: "no options",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{},
+ expect: []byte{33, 0, 1, 4, 0, 0, 0, 0},
+ },
+ {
+ name: "Router Alert",
+ nextHeader: 33,
+ options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}},
+ expect: []byte{33, 0, 5, 2, 0, 0, 1, 0},
+ validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) {
+ t.Helper()
+ routerAlert, ok := deserialized.(*IPv6RouterAlertOption)
+ if !ok {
+ t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized)
+ }
+ if routerAlert.Value != IPv6RouterAlertMLD {
+ t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD)
+ }
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := IPv6SerializableHopByHopExtHdr(test.options)
+ length := s.length()
+ if length != len(test.expect) {
+ t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect))
+ }
+ b := make([]byte, length)
+ for i := range b {
+ // Fill the buffer with ones to ensure all padding is correctly set.
+ b[i] = 0xFF
+ }
+ if got := s.serializeInto(test.nextHeader, b); got != length {
+ t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length)
+ }
+ if diff := cmp.Diff(test.expect, b); diff != "" {
+ t.Fatalf("serialization mismatch (-want +got):\n%s", diff)
+ }
+
+ // Deserialize the options and verify them.
+ optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit
+ iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter()
+ for _, testOpt := range test.options {
+ opt, done, err := iter.Next()
+ if err != nil {
+ t.Fatalf("iter.Next(): %s", err)
+ }
+ if done {
+ t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
+ }
+ test.validate(t, testOpt, opt)
+ }
+ opt, done, err := iter.Next()
+ if err != nil {
+ t.Fatalf("iter.Next(): %s", err)
+ }
+ if !done {
+ t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
+ }
+ })
+ }
+}
+
+var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil)
+
+// dummyIPv6ExtHdrSerializer provides a generic implementation of
+// IPv6SerializableExtHdr for use in tests.
+//
+// The dummy header always carries the nextHeader value in the first byte.
+type dummyIPv6ExtHdrSerializer struct {
+ id IPv6ExtensionHeaderIdentifier
+ headerContents []byte
+}
+
+// identifier implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier {
+ return s.id
+}
+
+// length implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) length() int {
+ return len(s.headerContents) + 1
+}
+
+// serializeInto implements IPv6SerializableExtHdr.
+func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int {
+ b[0] = nextHeader
+ return copy(b[1:], s.headerContents) + 1
+}
+
+func TestIPv6ExtHdrSerializer(t *testing.T) {
+ tests := []struct {
+ name string
+ headers []IPv6SerializableExtHdr
+ nextHeader tcpip.TransportProtocolNumber
+ expectSerialized []byte
+ expectNextHeader uint8
+ }{
+ {
+ name: "one header",
+ headers: []IPv6SerializableExtHdr{
+ &dummyIPv6ExtHdrSerializer{
+ id: 15,
+ headerContents: []byte{1, 2, 3, 4},
+ },
+ },
+ nextHeader: TCPProtocolNumber,
+ expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4},
+ expectNextHeader: 15,
+ },
+ {
+ name: "two headers",
+ headers: []IPv6SerializableExtHdr{
+ &dummyIPv6ExtHdrSerializer{
+ id: 22,
+ headerContents: []byte{1, 2, 3},
+ },
+ &dummyIPv6ExtHdrSerializer{
+ id: 23,
+ headerContents: []byte{4, 5, 6},
+ },
+ },
+ nextHeader: ICMPv6ProtocolNumber,
+ expectSerialized: []byte{
+ 23, 1, 2, 3,
+ byte(ICMPv6ProtocolNumber), 4, 5, 6,
+ },
+ expectNextHeader: 22,
+ },
+ {
+ name: "no headers",
+ headers: []IPv6SerializableExtHdr{},
+ nextHeader: UDPProtocolNumber,
+ expectSerialized: []byte{},
+ expectNextHeader: byte(UDPProtocolNumber),
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := IPv6ExtHdrSerializer(test.headers)
+ l := s.Length()
+ if got, want := l, len(test.expectSerialized); got != want {
+ t.Fatalf("got serialized length = %d, want = %d", got, want)
+ }
+ b := make([]byte, l)
+ for i := range b {
+ // Fill the buffer with garbage to make sure we're writing to all bytes.
+ b[i] = 0xFF
+ }
+ nextHeader, serializedLen := s.Serialize(test.nextHeader, b)
+ if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader {
+ t.Errorf(
+ "got s.Serialize(..) = (%d, %d), want = (%d, %d)",
+ nextHeader,
+ serializedLen,
+ test.expectNextHeader,
+ len(test.expectSerialized),
+ )
+ }
+ if diff := cmp.Diff(test.expectSerialized, b); diff != "" {
+ t.Errorf("serialization mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go
index 018555a26..9d09f32eb 100644
--- a/pkg/tcpip/header/ipv6_fragment.go
+++ b/pkg/tcpip/header/ipv6_fragment.go
@@ -27,12 +27,11 @@ const (
idV6 = 4
)
-// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the
-// fields of a packet that needs to be encoded.
-type IPv6FragmentFields struct {
- // NextHeader is the "next header" field of an IPv6 fragment.
- NextHeader uint8
+var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil)
+// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment
+// extension header as defined in RFC 8200 section 4.5.
+type IPv6SerializableFragmentExtHdr struct {
// FragmentOffset is the "fragment offset" field of an IPv6 fragment.
FragmentOffset uint16
@@ -43,6 +42,29 @@ type IPv6FragmentFields struct {
Identification uint32
}
+// identifier implements IPv6SerializableFragmentExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
+ return IPv6FragmentHeader
+}
+
+// length implements IPv6SerializableFragmentExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) length() int {
+ return IPv6FragmentHeaderSize
+}
+
+// serializeInto implements IPv6SerializableFragmentExtHdr.
+func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int {
+ // Prevent too many bounds checks.
+ _ = b[IPv6FragmentHeaderSize:]
+ binary.BigEndian.PutUint32(b[idV6:], h.Identification)
+ binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift)
+ b[nextHdrFrag] = nextHeader
+ if h.M {
+ b[more] |= ipv6FragmentExtHdrMFlagMask
+ }
+ return IPv6FragmentHeaderSize
+}
+
// IPv6Fragment represents an ipv6 fragment header stored in a byte array.
// Most of the methods of IPv6Fragment access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
@@ -58,16 +80,6 @@ const (
IPv6FragmentHeaderSize = 8
)
-// Encode encodes all the fields of the ipv6 fragment.
-func (b IPv6Fragment) Encode(i *IPv6FragmentFields) {
- b[nextHdrFrag] = i.NextHeader
- binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3)
- if i.M {
- b[more] |= 1
- }
- binary.BigEndian.PutUint32(b[idV6:], i.Identification)
-}
-
// IsValid performs basic validation on the fragment header.
func (b IPv6Fragment) IsValid() bool {
return len(b) >= IPv6FragmentHeaderSize
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index 426a873b1..e3fbd64f3 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -215,48 +215,6 @@ func TestLinkLocalAddrWithOpaqueIID(t *testing.T) {
}
}
-func TestIsV6UniqueLocalAddress(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- expected bool
- }{
- {
- name: "Valid Unique 1",
- addr: uniqueLocalAddr1,
- expected: true,
- },
- {
- name: "Valid Unique 2",
- addr: uniqueLocalAddr1,
- expected: true,
- },
- {
- name: "Link Local",
- addr: linkLocalAddr,
- expected: false,
- },
- {
- name: "Global",
- addr: globalAddr,
- expected: false,
- },
- {
- name: "IPv4",
- addr: "\x01\x02\x03\x04",
- expected: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected {
- t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
- }
- })
- }
-}
-
func TestIsV6LinkLocalMulticastAddress(t *testing.T) {
tests := []struct {
name string
@@ -346,7 +304,7 @@ func TestScopeForIPv6Address(t *testing.T) {
{
name: "Unique Local",
addr: uniqueLocalAddr1,
- scope: header.UniqueLocalScope,
+ scope: header.GlobalScope,
err: nil,
},
{
diff --git a/pkg/tcpip/header/mld.go b/pkg/tcpip/header/mld.go
new file mode 100644
index 000000000..ffe03c76a
--- /dev/null
+++ b/pkg/tcpip/header/mld.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // MLDMinimumSize is the minimum size for an MLD message.
+ MLDMinimumSize = 20
+
+ // MLDHopLimit is the Hop Limit for all IPv6 packets with an MLD message, as
+ // per RFC 2710 section 3.
+ MLDHopLimit = 1
+
+ // mldMaximumResponseDelayOffset is the offset to the Maximum Response Delay
+ // field within MLD.
+ mldMaximumResponseDelayOffset = 0
+
+ // mldMulticastAddressOffset is the offset to the Multicast Address field
+ // within MLD.
+ mldMulticastAddressOffset = 4
+)
+
+// MLD is a Multicast Listener Discovery message in an ICMPv6 packet.
+//
+// MLD will only contain the body of an ICMPv6 packet.
+//
+// As per RFC 2710 section 3, MLD messages have the following format (MLD only
+// holds the bytes after the first four bytes in the diagram below):
+//
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Type | Code | Checksum |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Maximum Response Delay | Reserved |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | |
+// + +
+// | |
+// + Multicast Address +
+// | |
+// + +
+// | |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+type MLD []byte
+
+// MaximumResponseDelay returns the Maximum Response Delay.
+func (m MLD) MaximumResponseDelay() time.Duration {
+ // As per RFC 2710 section 3.4:
+ //
+ // The Maximum Response Delay field is meaningful only in Query
+ // messages, and specifies the maximum allowed delay before sending a
+ // responding Report, in units of milliseconds. In all other messages,
+ // it is set to zero by the sender and ignored by receivers.
+ return time.Duration(binary.BigEndian.Uint16(m[mldMaximumResponseDelayOffset:])) * time.Millisecond
+}
+
+// SetMaximumResponseDelay sets the Maximum Response Delay field.
+//
+// maxRespDelayMS is the value in milliseconds.
+func (m MLD) SetMaximumResponseDelay(maxRespDelayMS uint16) {
+ binary.BigEndian.PutUint16(m[mldMaximumResponseDelayOffset:], maxRespDelayMS)
+}
+
+// MulticastAddress returns the Multicast Address.
+func (m MLD) MulticastAddress() tcpip.Address {
+ // As per RFC 2710 section 3.5:
+ //
+ // In a Query message, the Multicast Address field is set to zero when
+ // sending a General Query, and set to a specific IPv6 multicast address
+ // when sending a Multicast-Address-Specific Query.
+ //
+ // In a Report or Done message, the Multicast Address field holds a
+ // specific IPv6 multicast address to which the message sender is
+ // listening or is ceasing to listen, respectively.
+ return tcpip.Address(m[mldMulticastAddressOffset:][:IPv6AddressSize])
+}
+
+// SetMulticastAddress sets the Multicast Address field.
+func (m MLD) SetMulticastAddress(multicastAddress tcpip.Address) {
+ if n := copy(m[mldMulticastAddressOffset:], multicastAddress); n != IPv6AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, IPv6AddressSize))
+ }
+}
diff --git a/pkg/tcpip/header/mld_test.go b/pkg/tcpip/header/mld_test.go
new file mode 100644
index 000000000..0cecf10d4
--- /dev/null
+++ b/pkg/tcpip/header/mld_test.go
@@ -0,0 +1,61 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestMLD(t *testing.T) {
+ b := []byte{
+ // Maximum Response Delay
+ 0, 0,
+
+ // Reserved
+ 0, 0,
+
+ // MulticastAddress
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6,
+ }
+
+ const maxRespDelay = 513
+ binary.BigEndian.PutUint16(b, maxRespDelay)
+
+ mld := MLD(b)
+
+ if got, want := mld.MaximumResponseDelay(), maxRespDelay*time.Millisecond; got != want {
+ t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want)
+ }
+
+ const newMaxRespDelay = 1234
+ mld.SetMaximumResponseDelay(newMaxRespDelay)
+ if got, want := mld.MaximumResponseDelay(), newMaxRespDelay*time.Millisecond; got != want {
+ t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want)
+ }
+
+ if got, want := mld.MulticastAddress(), tcpip.Address([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want {
+ t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, want)
+ }
+
+ multicastAddress := tcpip.Address([]byte{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0})
+ mld.SetMulticastAddress(multicastAddress)
+ if got := mld.MulticastAddress(); got != multicastAddress {
+ t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, multicastAddress)
+ }
+}
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 5d3975c56..554242f0c 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -298,7 +298,7 @@ func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) {
return it, nil
}
-// Serialize serializes the provided list of NDP options into o.
+// Serialize serializes the provided list of NDP options into b.
//
// Note, b must be of sufficient size to hold all the options in s. See
// NDPOptionsSerializer.Length for details on the getting the total size
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go
index 5ca75c834..2042f214a 100644
--- a/pkg/tcpip/header/parse/parse.go
+++ b/pkg/tcpip/header/parse/parse.go
@@ -109,6 +109,9 @@ traverseExtensions:
fragOffset = extHdr.FragmentOffset()
fragMore = extHdr.More()
}
+ rawPayload := it.AsRawHeader(true /* consume */)
+ extensionsSize = dataClone.Size() - rawPayload.Buf.Size()
+ break traverseExtensions
case header.IPv6RawPayloadHeader:
// We've found the payload after any extensions.
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 39ca774ef..973f06cbc 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -9,7 +9,6 @@ go_library(
deps = [
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index c95aef63c..d9f8e3b35 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -32,7 +31,7 @@ type PacketInfo struct {
Pkt *stack.PacketBuffer
Proto tcpip.NetworkProtocolNumber
GSO *stack.GSO
- Route stack.Route
+ Route stack.RouteInfo
}
// Notification is the interface for receiving notification from the packet
@@ -231,15 +230,11 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket stores outbound packets into the channel.
func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- // Clone r then release its resource so we only get the relevant fields from
- // stack.Route without holding a reference to a NIC's endpoint.
- route := r.Clone()
- route.Release()
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
GSO: gso,
- Route: route,
+ Route: r.GetFields(),
}
e.q.Write(p)
@@ -249,17 +244,13 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// WritePackets stores outbound packets into the channel.
func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- // Clone r then release its resource so we only get the relevant fields from
- // stack.Route without holding a reference to a NIC's endpoint.
- route := r.Clone()
- route.Release()
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
GSO: gso,
- Route: route,
+ Route: r.GetFields(),
}
if !e.q.Write(p) {
@@ -271,21 +262,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- p := PacketInfo{
- Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }),
- Proto: 0,
- GSO: nil,
- }
-
- e.q.Write(p)
-
- return nil
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (*Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD
index ec92ed623..0ae0d201a 100644
--- a/pkg/tcpip/link/ethernet/BUILD
+++ b/pkg/tcpip/link/ethernet/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -13,3 +13,17 @@ go_library(
"//pkg/tcpip/stack",
],
)
+
+go_test(
+ name = "ethernet_test",
+ size = "small",
+ srcs = ["ethernet_test.go"],
+ deps = [
+ ":ethernet",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index 3eef7cd56..89e3e6164 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -49,10 +49,10 @@ func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkP
return
}
+ // Note, there is no need to check the destination link address here since
+ // the ethernet hardware filters frames based on their destination addresses.
eth := header.Ethernet(hdr)
- if dst := eth.DestinationAddress(); dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) {
- e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, dst /* local */, eth.Type() /* protocol */, pkt)
- }
+ e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, eth.DestinationAddress() /* local */, eth.Type() /* protocol */, pkt)
}
// Capabilities implements stack.LinkEndpoint.
@@ -62,7 +62,7 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
// WritePacket implements stack.LinkEndpoint.
func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress(), proto, pkt)
return e.Endpoint.WritePacket(r, gso, proto, pkt)
}
@@ -71,7 +71,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkAddr := e.Endpoint.LinkAddress()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(linkAddr, r.RemoteLinkAddress(), proto, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, proto)
diff --git a/pkg/tcpip/link/ethernet/ethernet_test.go b/pkg/tcpip/link/ethernet/ethernet_test.go
new file mode 100644
index 000000000..08a7f1ce1
--- /dev/null
+++ b/pkg/tcpip/link/ethernet/ethernet_test.go
@@ -0,0 +1,71 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ethernet_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.NetworkDispatcher = (*testNetworkDispatcher)(nil)
+
+type testNetworkDispatcher struct {
+ networkPackets int
+}
+
+func (t *testNetworkDispatcher) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+ t.networkPackets++
+}
+
+func (*testNetworkDispatcher) DeliverOutboundPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
+func TestDeliverNetworkPacket(t *testing.T) {
+ const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ otherLinkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ otherLinkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
+ )
+
+ e := ethernet.New(channel.New(0, 0, linkAddr))
+ var networkDispatcher testNetworkDispatcher
+ e.Attach(&networkDispatcher)
+
+ if networkDispatcher.networkPackets != 0 {
+ t.Fatalf("got networkDispatcher.networkPackets = %d, want = 0", networkDispatcher.networkPackets)
+ }
+
+ // An ethernet frame with a destination link address that is not assigned to
+ // our ethernet link endpoint should still be delivered to the network
+ // dispatcher since the ethernet endpoint is not expected to filter frames.
+ eth := buffer.NewView(header.EthernetMinimumSize)
+ header.Ethernet(eth).Encode(&header.EthernetFields{
+ SrcAddr: otherLinkAddr1,
+ DstAddr: otherLinkAddr2,
+ Type: header.IPv4ProtocolNumber,
+ })
+ e.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: eth.ToVectorisedView(),
+ }))
+ if networkDispatcher.networkPackets != 1 {
+ t.Fatalf("got networkDispatcher.networkPackets = %d, want = 1", networkDispatcher.networkPackets)
+ }
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 975309fc8..cb94cbea6 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -284,9 +284,12 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher
}
switch sa.(type) {
case *unix.SockaddrLinklayer:
- // enable PACKET_FANOUT mode is the underlying socket is
- // of type AF_PACKET.
- const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG
+ // Enable PACKET_FANOUT mode if the underlying socket is of type
+ // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will
+ // prevent gvisor from receiving fragmented packets and the host does the
+ // reassembly on our behalf before delivering the fragments. This makes it
+ // hard to test fragmentation reassembly code in Netstack.
+ const fanoutType = unix.PACKET_FANOUT_HASH
fanoutArg := fanoutID | fanoutType<<16
if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
@@ -410,7 +413,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
if e.hdrSize > 0 {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt)
}
var builder iovec.Builder
@@ -453,7 +456,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
if e.hdrSize > 0 {
- e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress(), pkt.NetworkProtocolNumber, pkt)
}
var vnetHdrBuf []byte
@@ -558,11 +561,6 @@ func viewsEqual(vs1, vs2 []buffer.View) bool {
return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0])
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- return rawfile.NonBlockingWrite(e.fds[0], vv.ToView())
-}
-
// InjectOutobund implements stack.InjectableEndpoint.InjectOutbound.
func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
return rawfile.NonBlockingWrite(e.fds[0], packet)
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 709f829c8..987a34226 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -183,9 +183,8 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize})
defer c.cleanup()
- r := &stack.Route{
- RemoteLinkAddress: raddr,
- }
+ var r stack.Route
+ r.ResolveWith(raddr)
// Build payload.
payload := buffer.NewView(plen)
@@ -220,7 +219,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
L3HdrLen: header.IPv4MaximumHeaderSize,
}
}
- if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil {
+ if err := c.ep.WritePacket(&r, gso, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -324,10 +323,9 @@ func TestPreserveSrcAddress(t *testing.T) {
defer c.cleanup()
// Set LocalLinkAddress in route to the value of the bridged address.
- r := &stack.Route{
- RemoteLinkAddress: raddr,
- LocalLinkAddress: baddr,
- }
+ var r stack.Route
+ r.LocalLinkAddress = baddr
+ r.ResolveWith(raddr)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
// WritePacket panics given a prependable with anything less than
@@ -336,7 +334,7 @@ func TestPreserveSrcAddress(t *testing.T) {
ReserveHeaderBytes: header.EthernetMinimumSize,
Data: buffer.VectorisedView{},
})
- if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -503,7 +501,7 @@ func TestRecvMMsgDispatcherCapLength(t *testing.T) {
msgHdrs: make([]rawfile.MMsgHdr, 1),
}
- for i, _ := range d.views {
+ for i := range d.views {
d.views[i] = make([]buffer.View, len(c.config))
}
for i := range d.iovecs {
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 38aa694e4..edca57e4e 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -96,23 +96,6 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList
panic("not implemented")
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- // There should be an ethernet header at the beginning of vv.
- hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
- if !ok {
- // Reject the packet if it's shorter than an ethernet header.
- return tcpip.ErrBadAddress
- }
- linkHeader := header.Ethernet(hdr)
- e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt)
-
- return nil
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareLoopback
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index e7493e5c5..cbda59775 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 56a611825..22e79ce3a 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -17,7 +17,6 @@ package muxed
import (
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -106,13 +105,6 @@ func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protoco
return tcpip.ErrNoRoute
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- // WriteRawPacket doesn't get a route or network address, so there's
- // nowhere to write this.
- return tcpip.ErrNoRoute
-}
-
// InjectOutbound writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the dest address.
func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 3e4afcdad..b511d3a31 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -51,7 +51,8 @@ func TestInjectableEndpointDispatch(t *testing.T) {
Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
})
pkt.TransportHeader().Push(1)[0] = 0xFA
- packetRoute := stack.Route{RemoteAddress: dstIP}
+ var packetRoute stack.Route
+ packetRoute.RemoteAddress = dstIP
endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt)
@@ -73,7 +74,8 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
Data: buffer.NewView(0).ToVectorisedView(),
})
pkt.TransportHeader().Push(1)[0] = 0xFA
- packetRoute := stack.Route{RemoteAddress: dstIP}
+ var packetRoute stack.Route
+ packetRoute.RemoteAddress = dstIP
endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt)
buf := make([]byte, 6500)
bytesRead, err := sock.Read(buf)
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD
index 2cdb23475..00b42b924 100644
--- a/pkg/tcpip/link/nested/BUILD
+++ b/pkg/tcpip/link/nested/BUILD
@@ -11,7 +11,6 @@ go_library(
deps = [
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index d40de54df..0ee54c3d5 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -19,7 +19,6 @@ package nested
import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -123,11 +122,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return e.child.WritePackets(r, gso, pkts, protocol)
}
-// WriteRawPacket implements stack.LinkEndpoint.
-func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- return e.child.WriteRawPacket(vv)
-}
-
// Wait implements stack.LinkEndpoint.
func (e *Endpoint) Wait() {
e.child.Wait()
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
index 3922c2a04..9a1b0c0c2 100644
--- a/pkg/tcpip/link/packetsocket/endpoint.go
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -36,14 +36,14 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
// WritePacket implements stack.LinkEndpoint.WritePacket.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
+ e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress(), r.LocalLinkAddress, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress(), pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, proto)
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 523b0d24b..25c364391 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -55,7 +55,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.Network
// remote address from the perspective of the other end of the pipe
// (e.linked). Similarly, the remote address from the perspective of this
// endpoint is the local address on the other end.
- e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress() /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
}))
@@ -67,11 +67,6 @@ func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList,
panic("not implemented")
}
-// WriteRawPacket implements stack.LinkEndpoint.
-func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- panic("not implemented")
-}
-
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD
index 1d0079bd6..5bea598eb 100644
--- a/pkg/tcpip/link/qdisc/fifo/BUILD
+++ b/pkg/tcpip/link/qdisc/fifo/BUILD
@@ -13,7 +13,6 @@ go_library(
"//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index fc1e34fc7..b7458b620 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -155,8 +154,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
// WritePacket caller's do not set the following fields in PacketBuffer
// so we populate them here.
- newRoute := r.Clone()
- pkt.EgressRoute = &newRoute
+ pkt.EgressRoute = r
pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = protocol
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
@@ -179,11 +177,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
for pkt := pkts.Front(); pkt != nil; {
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
nxt := pkt.Next()
- // Since qdisc can hold onto a packet for long we should Clone
- // the route here to ensure it doesn't get released while the
- // packet is still in our queue.
- newRoute := pkt.EgressRoute.Clone()
- pkt.EgressRoute = &newRoute
if !d.q.enqueue(pkt) {
if enqueued > 0 {
d.newPacketWaker.Assert()
@@ -197,13 +190,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
return enqueued, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- // TODO(gvisor.dev/issue/3267): Queue these packets as well once
- // WriteRawPacket takes PacketBuffer instead of VectorisedView.
- return e.lower.WriteRawPacket(vv)
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (e *endpoint) Wait() {
e.lower.Wait()
diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
index eb5abb906..45adcbccb 100644
--- a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
+++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
@@ -61,6 +61,7 @@ func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool {
q.mu.Lock()
r := q.used < q.limit
if r {
+ s.EgressRoute.Acquire()
q.list.PushBack(s)
q.used++
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 7fb8a6c49..5660418fa 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -204,7 +204,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt)
views := pkt.Views()
// Transmit the packet.
@@ -224,21 +224,6 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB
panic("not implemented")
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- views := vv.Views()
- // Transmit the packet.
- e.mu.Lock()
- ok := e.tx.transmit(views...)
- e.mu.Unlock()
-
- if !ok {
- return tcpip.ErrWouldBlock
- }
-
- return nil
-}
-
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
// to the network stack.
func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 22d5c97f1..dd2e1a125 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -260,9 +260,8 @@ func TestSimpleSend(t *testing.T) {
defer c.cleanup()
// Prepare route.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
for iters := 1000; iters > 0; iters-- {
func() {
@@ -341,10 +340,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6))
// Set both remote and local link address in route.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- LocalLinkAddress: newLocalLinkAddress,
- }
+ var r stack.Route
+ r.LocalLinkAddress = newLocalLinkAddress
+ r.ResolveWith(remoteLinkAddr)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
// WritePacket panics given a prependable with anything less than
@@ -395,9 +393,8 @@ func TestFillTxQueue(t *testing.T) {
defer c.cleanup()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
@@ -444,9 +441,8 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
c.txq.rx.Flush()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
@@ -509,9 +505,8 @@ func TestFillTxMemory(t *testing.T) {
defer c.cleanup()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
@@ -557,9 +552,8 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
defer c.cleanup()
// Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.ResolveWith(remoteLinkAddr)
buf := buffer.NewView(100)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index b3e8c4b92..1a2cc39eb 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -53,16 +53,35 @@ type endpoint struct {
nested.Endpoint
writer io.Writer
maxPCAPLen uint32
+ logPrefix string
}
var _ stack.GSOEndpoint = (*endpoint)(nil)
var _ stack.LinkEndpoint = (*endpoint)(nil)
var _ stack.NetworkDispatcher = (*endpoint)(nil)
+type direction int
+
+const (
+ directionSend = iota
+ directionRecv
+)
+
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets and they traverse the endpoint.
func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
- sniffer := &endpoint{}
+ return NewWithPrefix(lower, "")
+}
+
+// NewWithPrefix creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets prefixed with logPrefix as they traverse
+// the endpoint.
+//
+// logPrefix is prepended to the log line without any separators.
+// E.g. logPrefix = "NIC:en0/" will produce log lines like
+// "NIC:en0/send udp [...]".
+func NewWithPrefix(lower stack.LinkEndpoint, logPrefix string) stack.LinkEndpoint {
+ sniffer := &endpoint{logPrefix: logPrefix}
sniffer.Endpoint.Init(lower, sniffer)
return sniffer
}
@@ -120,7 +139,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.dumpPacket("recv", nil, protocol, pkt)
+ e.dumpPacket(directionRecv, nil, protocol, pkt)
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
@@ -129,10 +148,10 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc
e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
}
-func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
- logPacket(prefix, protocol, pkt, gso)
+ logPacket(e.logPrefix, dir, protocol, pkt, gso)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
totalLength := pkt.Size()
@@ -169,7 +188,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.dumpPacket("send", gso, protocol, pkt)
+ e.dumpPacket(directionSend, gso, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
@@ -178,20 +197,12 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
// forwards the request to the lower endpoint.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.dumpPacket("send", gso, protocol, pkt)
+ e.dumpPacket(directionSend, gso, protocol, pkt)
}
return e.Endpoint.WritePackets(r, gso, pkts, protocol)
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- return e.Endpoint.WriteRawPacket(vv)
-}
-
-func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
+func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -201,6 +212,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
var fragmentOffset uint16
var moreFragments bool
+ var directionPrefix string
+ switch dir {
+ case directionSend:
+ directionPrefix = "send"
+ case directionRecv:
+ directionPrefix = "recv"
+ default:
+ panic(fmt.Sprintf("unrecognized direction: %d", dir))
+ }
+
// Clone the packet buffer to not modify the original.
//
// We don't clone the original packet buffer so that the new packet buffer
@@ -242,21 +263,22 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
fragmentOffset = fragOffset
case header.ARPProtocolNumber:
- if parse.ARP(pkt) {
+ if !parse.ARP(pkt) {
return
}
arp := header.ARP(pkt.NetworkHeader().View())
log.Infof(
- "%s arp %s (%s) -> %s (%s) valid:%t",
+ "%s%s arp %s (%s) -> %s (%s) valid:%t",
prefix,
+ directionPrefix,
tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
arp.IsValid(),
)
return
default:
- log.Infof("%s unknown network protocol", prefix)
+ log.Infof("%s%s unknown network protocol", prefix, directionPrefix)
return
}
@@ -300,7 +322,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
icmpType = "info reply"
}
}
- log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.ICMPv6ProtocolNumber:
@@ -335,7 +357,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.ICMPv6RedirectMsg:
icmpType = "redirect message"
}
- log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.UDPProtocolNumber:
@@ -391,7 +413,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
}
default:
- log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto)
+ log.Infof("%s%s %s -> %s unknown transport protocol: %d", prefix, directionPrefix, src, dst, transProto)
return
}
@@ -399,5 +421,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
details += fmt.Sprintf(" gso: %+v", gso)
}
- log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
+ log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, directionPrefix, transName, src, srcPort, dst, dstPort, size, id, details)
}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 4c14f55d3..bfac358f4 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -76,29 +76,13 @@ func (d *Device) Release(ctx context.Context) {
}
}
-// NICID returns the NIC ID of the device.
-//
-// Must only be called after the device has been attached to an endpoint.
-func (d *Device) NICID() tcpip.NICID {
- d.mu.RLock()
- defer d.mu.RUnlock()
-
- if d.endpoint == nil {
- panic("called NICID on a device that has not been attached")
- }
-
- return d.endpoint.nicID
-}
-
// SetIff services TUNSETIFF ioctl(2) request.
-//
-// Returns true if a new NIC was created; false if an existing one was attached.
-func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) {
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
d.mu.Lock()
defer d.mu.Unlock()
if d.endpoint != nil {
- return false, syserror.EINVAL
+ return syserror.EINVAL
}
// Input validations.
@@ -106,7 +90,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error)
isTap := flags&linux.IFF_TAP != 0
supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
- return false, syserror.EINVAL
+ return syserror.EINVAL
}
prefix := "tun"
@@ -119,18 +103,18 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error)
linkCaps |= stack.CapabilityResolutionRequired
}
- endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps)
+ endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
- return false, syserror.EINVAL
+ return syserror.EINVAL
}
d.endpoint = endpoint
d.notifyHandle = d.endpoint.AddNotify(d)
d.flags = flags
- return created, nil
+ return nil
}
-func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) {
+func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
for {
// 1. Try to attach to an existing NIC.
if name != "" {
@@ -138,13 +122,13 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
endpoint, ok := linkEP.(*tunEndpoint)
if !ok {
// Not a NIC created by tun device.
- return nil, false, syserror.EOPNOTSUPP
+ return nil, syserror.EOPNOTSUPP
}
if !endpoint.TryIncRef() {
// Race detected: NIC got deleted in between.
continue
}
- return endpoint, false, nil
+ return endpoint, nil
}
}
@@ -167,12 +151,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
})
switch err {
case nil:
- return endpoint, true, nil
+ return endpoint, nil
case tcpip.ErrDuplicateNICID:
// Race detected: A NIC has been created in between.
continue
default:
- return nil, false, syserror.EINVAL
+ return nil, syserror.EINVAL
}
}
}
@@ -280,7 +264,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
// If the packet does not already have link layer header, and the route
// does not exist, we can't compute it. This is possibly a raw packet, tun
// device doesn't support this at the moment.
- if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" {
+ if info.Pkt.LinkHeader().View().IsEmpty() && len(info.Route.RemoteLinkAddress) == 0 {
return nil, false
}
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index ee84c3d96..9b4602c1b 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -11,7 +11,6 @@ go_library(
deps = [
"//pkg/gate",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
@@ -25,7 +24,6 @@ go_test(
library = ":waitable",
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index b152a0f26..cf0077f43 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -24,7 +24,6 @@ package waitable
import (
"gvisor.dev/gvisor/pkg/gate"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -132,17 +131,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n, err
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- if !e.writeGate.Enter() {
- return nil
- }
-
- err := e.lower.WriteRawPacket(vv)
- e.writeGate.Leave()
- return err
-}
-
// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint,
// and waits for inflight ones to finish before returning.
func (e *Endpoint) WaitWrite() {
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 94827fc56..cf7fb5126 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -18,7 +18,6 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -81,11 +80,6 @@ func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.
return pkts.Len(), nil
}
-func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- e.writeCount++
- return nil
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unimplemented")
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index c118a2929..9ebf31b78 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -7,13 +7,16 @@ go_test(
size = "small",
srcs = [
"ip_test.go",
+ "multicast_group_test.go",
],
deps = [
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 33a4a0720..3d5c0d270 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -31,17 +31,15 @@ import (
const (
// ProtocolNumber is the ARP protocol number.
ProtocolNumber = header.ARPProtocolNumber
-
- // ProtocolAddress is the address expected by the ARP endpoint.
- ProtocolAddress = tcpip.Address("arp")
)
-var _ stack.AddressableEndpoint = (*endpoint)(nil)
+// ARP endpoints need to implement stack.NetworkEndpoint because the stack
+// considers the layer above the link-layer a network layer; the only
+// facility provided by the stack to deliver packets to a layer above
+// the link-layer is via stack.NetworkEndpoint.HandlePacket.
var _ stack.NetworkEndpoint = (*endpoint)(nil)
type endpoint struct {
- stack.AddressableEndpointState
-
protocol *protocol
// enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
@@ -87,7 +85,7 @@ func (e *endpoint) Disable() {
}
// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
-func (e *endpoint) DefaultTTL() uint8 {
+func (*endpoint) DefaultTTL() uint8 {
return 0
}
@@ -100,25 +98,23 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {
- e.AddressableEndpointState.Cleanup()
-}
+func (*endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
+func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
-func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return ProtocolNumber
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
return 0, tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -216,9 +212,8 @@ func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber
func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
func (p *protocol) DefaultPrefixLen() int { return 0 }
-func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- h := header.ARP(v)
- return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
+func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) {
+ return "", ""
}
func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
@@ -228,7 +223,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
linkAddrCache: linkAddrCache,
nud: nud,
}
- e.AddressableEndpointState.Init(e)
return e
}
@@ -311,10 +305,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
}
// NewProtocol returns an ARP network protocol.
-//
-// Note, to make sure that the ARP endpoint receives ARP packets, the "arp"
-// address must be added to every NIC that should respond to ARP requests. See
-// ProtocolAddress for more details.
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return &protocol{stack: s}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 087ee9c66..a25cba513 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -200,9 +200,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext {
t.Fatalf("AddAddress for ipv4 failed: %v", err)
}
}
- if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("AddAddress for arp failed: %v", err)
- }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
@@ -322,9 +319,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
copy(h.HardwareAddressSender(), test.senderLinkAddr)
copy(h.ProtocolAddressSender(), test.senderAddr)
copy(h.ProtocolAddressTarget(), test.targetAddr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{
+ c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: v.ToVectorisedView(),
- })
+ }))
if !test.isValid {
// No packets should be sent after receiving an invalid ARP request.
@@ -439,11 +436,14 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- r := stack.Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.NetProto = protocol
+ r.ResolveWith(remoteLinkAddr)
return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 47fb63290..429af69ee 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -18,7 +18,6 @@ go_template_instance(
go_library(
name = "fragmentation",
srcs = [
- "frag_heap.go",
"fragmentation.go",
"reassembler.go",
"reassembler_list.go",
@@ -38,7 +37,6 @@ go_test(
name = "fragmentation_test",
size = "small",
srcs = [
- "frag_heap_test.go",
"fragmentation_test.go",
"reassembler_test.go",
],
@@ -47,6 +45,7 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/faketime",
"//pkg/tcpip/network/testutil",
+ "//pkg/tcpip/stack",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go
deleted file mode 100644
index 0b570d25a..000000000
--- a/pkg/tcpip/network/fragmentation/frag_heap.go
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "container/heap"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-type fragment struct {
- offset uint16
- vv buffer.VectorisedView
-}
-
-type fragHeap []fragment
-
-func (h *fragHeap) Len() int {
- return len(*h)
-}
-
-func (h *fragHeap) Less(i, j int) bool {
- return (*h)[i].offset < (*h)[j].offset
-}
-
-func (h *fragHeap) Swap(i, j int) {
- (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
-}
-
-func (h *fragHeap) Push(x interface{}) {
- *h = append(*h, x.(fragment))
-}
-
-func (h *fragHeap) Pop() interface{} {
- old := *h
- n := len(old)
- x := old[n-1]
- *h = old[:n-1]
- return x
-}
-
-// reassamble empties the heap and returns a VectorisedView
-// containing a reassambled version of the fragments inside the heap.
-func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
- curr := heap.Pop(h).(fragment)
- views := curr.vv.Views()
- size := curr.vv.Size()
-
- if curr.offset != 0 {
- return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
- }
-
- for h.Len() > 0 {
- curr := heap.Pop(h).(fragment)
- if int(curr.offset) < size {
- curr.vv.TrimFront(size - int(curr.offset))
- } else if int(curr.offset) > size {
- return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
- }
- size += curr.vv.Size()
- views = append(views, curr.vv.Views()...)
- }
- return buffer.NewVectorisedView(size, views), nil
-}
diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go
deleted file mode 100644
index 9ececcb9f..000000000
--- a/pkg/tcpip/network/fragmentation/frag_heap_test.go
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "container/heap"
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-var reassambleTestCases = []struct {
- comment string
- in []fragment
- want buffer.VectorisedView
-}{
- {
- comment: "Non-overlapping in-order",
- in: []fragment{
- {offset: 0, vv: vv(1, "0")},
- {offset: 1, vv: vv(1, "1")},
- },
- want: vv(2, "0", "1"),
- },
- {
- comment: "Non-overlapping out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(1, "1")},
- {offset: 0, vv: vv(1, "0")},
- },
- want: vv(2, "0", "1"),
- },
- {
- comment: "Duplicated packets",
- in: []fragment{
- {offset: 0, vv: vv(1, "0")},
- {offset: 0, vv: vv(1, "0")},
- },
- want: vv(1, "0"),
- },
- {
- comment: "Overlapping in-order",
- in: []fragment{
- {offset: 0, vv: vv(2, "01")},
- {offset: 1, vv: vv(2, "12")},
- },
- want: vv(3, "01", "2"),
- },
- {
- comment: "Overlapping out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(2, "12")},
- {offset: 0, vv: vv(2, "01")},
- },
- want: vv(3, "01", "2"),
- },
- {
- comment: "Overlapping subset in-order",
- in: []fragment{
- {offset: 0, vv: vv(3, "012")},
- {offset: 1, vv: vv(1, "1")},
- },
- want: vv(3, "012"),
- },
- {
- comment: "Overlapping subset out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(1, "1")},
- {offset: 0, vv: vv(3, "012")},
- },
- want: vv(3, "012"),
- },
-}
-
-func TestReassamble(t *testing.T) {
- for _, c := range reassambleTestCases {
- t.Run(c.comment, func(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- for _, f := range c.in {
- heap.Push(&h, f)
- }
- got, err := h.reassemble()
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(got, c.want) {
- t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want)
- }
- })
- }
-}
-
-func TestReassambleFailsForNonZeroOffset(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")})
- _, err := h.reassemble()
- if err == nil {
- t.Errorf("reassemble() did not fail when the first packet had offset != 0")
- }
-}
-
-func TestReassambleFailsForHoles(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")})
- heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")})
- _, err := h.reassemble()
- if err == nil {
- t.Errorf("reassemble() did not fail when there was a hole in the packet")
- }
-}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 936601287..1af87d713 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -46,9 +46,17 @@ const (
)
var (
- // ErrInvalidArgs indicates to the caller that that an invalid argument was
+ // ErrInvalidArgs indicates to the caller that an invalid argument was
// provided.
ErrInvalidArgs = errors.New("invalid args")
+
+ // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps
+ // with another one.
+ ErrFragmentOverlap = errors.New("overlapping fragments")
+
+ // ErrFragmentConflict indicates that, during reassembly, some fragments are
+ // in conflict with one another.
+ ErrFragmentConflict = errors.New("conflicting fragments")
)
// FragmentID is the identifier for a fragment.
@@ -71,16 +79,25 @@ type FragmentID struct {
// Fragmentation is the main structure that other modules
// of the stack should use to implement IP Fragmentation.
type Fragmentation struct {
- mu sync.Mutex
- highLimit int
- lowLimit int
- reassemblers map[FragmentID]*reassembler
- rList reassemblerList
- size int
- timeout time.Duration
- blockSize uint16
- clock tcpip.Clock
- releaseJob *tcpip.Job
+ mu sync.Mutex
+ highLimit int
+ lowLimit int
+ reassemblers map[FragmentID]*reassembler
+ rList reassemblerList
+ size int
+ timeout time.Duration
+ blockSize uint16
+ clock tcpip.Clock
+ releaseJob *tcpip.Job
+ timeoutHandler TimeoutHandler
+}
+
+// TimeoutHandler is consulted if a packet reassembly has timed out.
+type TimeoutHandler interface {
+ // OnReassemblyTimeout will be called with the first fragment (or nil, if the
+ // first fragment has not been received) of a packet whose reassembly has
+ // timed out.
+ OnReassemblyTimeout(pkt *stack.PacketBuffer)
}
// NewFragmentation creates a new Fragmentation.
@@ -97,7 +114,7 @@ type Fragmentation struct {
// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new a packet with an
// already existing fragmentation-id arrives after the timeout.
-func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation {
+func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock, timeoutHandler TimeoutHandler) *Fragmentation {
if lowMemoryLimit >= highMemoryLimit {
lowMemoryLimit = highMemoryLimit
}
@@ -111,12 +128,13 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
}
f := &Fragmentation{
- reassemblers: make(map[FragmentID]*reassembler),
- highLimit: highMemoryLimit,
- lowLimit: lowMemoryLimit,
- timeout: reassemblingTimeout,
- blockSize: blockSize,
- clock: clock,
+ reassemblers: make(map[FragmentID]*reassembler),
+ highLimit: highMemoryLimit,
+ lowLimit: lowMemoryLimit,
+ timeout: reassemblingTimeout,
+ blockSize: blockSize,
+ clock: clock,
+ timeoutHandler: timeoutHandler,
}
f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked)
@@ -136,16 +154,8 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
// proto is the protocol number marked in the fragment being processed. It has
// to be given here outside of the FragmentID struct because IPv6 should not use
// the protocol to identify a fragment.
-//
-// releaseCB is a callback that will run when the fragment reassembly of a
-// packet is complete or cancelled. releaseCB take a a boolean argument which is
-// true iff the reassembly is cancelled due to timeout. releaseCB should be
-// passed only with the first fragment of a packet. If more than one releaseCB
-// are passed for the same packet, only the first releaseCB will be saved for
-// the packet and the succeeding ones will be dropped by running them
-// immediately with a false argument.
func (f *Fragmentation) Process(
- id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) (
+ id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (
buffer.VectorisedView, uint8, bool, error) {
if first > last {
return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
@@ -160,10 +170,9 @@ func (f *Fragmentation) Process(
return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
}
- if l := vv.Size(); l < int(fragmentSize) {
- return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
+ if l := pkt.Data.Size(); l != int(fragmentSize) {
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
}
- vv.CapLength(int(fragmentSize))
f.mu.Lock()
r, ok := f.reassemblers[id]
@@ -179,15 +188,9 @@ func (f *Fragmentation) Process(
f.releaseReassemblersLocked()
}
}
- if releaseCB != nil {
- if !r.setCallback(releaseCB) {
- // We got a duplicate callback. Release it immediately.
- releaseCB(false /* timedOut */)
- }
- }
f.mu.Unlock()
- res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv)
+ res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, pkt)
if err != nil {
// We probably got an invalid sequence of fragments. Just
// discard the reassembler and move on.
@@ -231,7 +234,9 @@ func (f *Fragmentation) release(r *reassembler, timedOut bool) {
f.size = 0
}
- r.release(timedOut) // releaseCB may run.
+ if h := f.timeoutHandler; timedOut && h != nil {
+ h.OnReassemblyTimeout(r.pkt)
+ }
}
// releaseReassemblersLocked releases already-expired reassemblers, then
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 5dcd10730..3a79688a8 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
// reassembleTimeout is dummy timeout used for testing, where the clock never
@@ -40,13 +41,19 @@ func vv(size int, pieces ...string) buffer.VectorisedView {
return buffer.NewVectorisedView(size, views)
}
+func pkt(size int, pieces ...string) *stack.PacketBuffer {
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv(size, pieces...),
+ })
+}
+
type processInput struct {
id FragmentID
first uint16
last uint16
more bool
proto uint8
- vv buffer.VectorisedView
+ pkt *stack.PacketBuffer
}
type processOutput struct {
@@ -63,8 +70,8 @@ var processTestCases = []struct {
{
comment: "One ID",
in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -74,8 +81,8 @@ var processTestCases = []struct {
{
comment: "Next Header protocol mismatch",
in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -85,10 +92,10 @@ var processTestCases = []struct {
{
comment: "Two IDs",
in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")},
- {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
+ {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")},
+ {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -102,17 +109,17 @@ var processTestCases = []struct {
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil)
firstFragmentProto := c.in[0].proto
for i, in := range c.in {
- vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil)
+ vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
if err != nil {
- t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s",
- in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err)
+ t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s",
+ in.id, in.first, in.last, in.more, in.proto, in.pkt, err)
}
if !reflect.DeepEqual(vv, c.out[i].vv) {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)",
- in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView())
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) = (%X, _, _, _), want = (%X, _, _, _)",
+ in.id, in.first, in.last, in.more, in.proto, in.pkt, vv.ToView(), c.out[i].vv.ToView())
}
if done != c.out[i].done {
t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)",
@@ -236,11 +243,11 @@ func TestReassemblingTimeout(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
clock := faketime.NewManualClock()
- f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock)
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil)
for _, event := range test.events {
clock.Advance(event.clockAdvance)
if frag := event.fragment; frag != nil {
- _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil)
+ _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data))
if err != nil {
t.Fatalf("%s: f.Process failed: %s", event.name, err)
}
@@ -257,17 +264,17 @@ func TestReassemblingTimeout(t *testing.T) {
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}, nil)
// Send first fragment with id = 0.
- f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil)
+ f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0"))
// Send first fragment with id = 1.
- f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil)
+ f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1"))
// Send first fragment with id = 2.
- f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil)
+ f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2"))
// Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
// evicted.
- f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil)
+ f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3"))
if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
@@ -281,11 +288,11 @@ func TestMemoryLimits(t *testing.T) {
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}, nil)
// Send first fragment with id = 0.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
// Send the same packet again.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0"))
got := f.size
want := 1
@@ -327,6 +334,7 @@ func TestErrors(t *testing.T) {
last: 3,
more: true,
data: "012",
+ err: ErrInvalidArgs,
},
{
name: "exact block size with more and too little data",
@@ -376,8 +384,8 @@ func TestErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
- _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil)
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil)
+ _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data))
if !errors.Is(err, test.err) {
t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
}
@@ -498,57 +506,92 @@ func TestPacketFragmenter(t *testing.T) {
}
}
-func TestReleaseCallback(t *testing.T) {
+type testTimeoutHandler struct {
+ pkt *stack.PacketBuffer
+}
+
+func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
+ h.pkt = pkt
+}
+
+func TestTimeoutHandler(t *testing.T) {
const (
proto = 99
)
- var result int
- var callbackReasonIsTimeout bool
- cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut }
- cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut }
+ pk1 := pkt(1, "1")
+ pk2 := pkt(1, "2")
+
+ type processParam struct {
+ first uint16
+ last uint16
+ more bool
+ pkt *stack.PacketBuffer
+ }
tests := []struct {
- name string
- callbacks []func(bool)
- timeout bool
- wantResult int
- wantCallbackReasonIsTimeout bool
+ name string
+ params []processParam
+ wantError bool
+ wantPkt *stack.PacketBuffer
}{
{
- name: "callback runs on release",
- callbacks: []func(bool){cb1},
- timeout: false,
- wantResult: 1,
- wantCallbackReasonIsTimeout: false,
- },
- {
- name: "first callback is nil",
- callbacks: []func(bool){nil, cb2},
- timeout: false,
- wantResult: 2,
- wantCallbackReasonIsTimeout: false,
+ name: "onTimeout runs",
+ params: []processParam{
+ {
+ first: 0,
+ last: 0,
+ more: true,
+ pkt: pk1,
+ },
+ },
+ wantError: false,
+ wantPkt: pk1,
},
{
- name: "two callbacks - first one is set",
- callbacks: []func(bool){cb1, cb2},
- timeout: false,
- wantResult: 1,
- wantCallbackReasonIsTimeout: false,
+ name: "no first fragment",
+ params: []processParam{
+ {
+ first: 1,
+ last: 1,
+ more: true,
+ pkt: pk1,
+ },
+ },
+ wantError: false,
+ wantPkt: nil,
},
{
- name: "callback runs on timeout",
- callbacks: []func(bool){cb1},
- timeout: true,
- wantResult: 1,
- wantCallbackReasonIsTimeout: true,
+ name: "second pkt is ignored",
+ params: []processParam{
+ {
+ first: 0,
+ last: 0,
+ more: true,
+ pkt: pk1,
+ },
+ {
+ first: 0,
+ last: 0,
+ more: true,
+ pkt: pk2,
+ },
+ },
+ wantError: false,
+ wantPkt: pk1,
},
{
- name: "no callbacks",
- callbacks: []func(bool){nil},
- timeout: false,
- wantResult: 0,
- wantCallbackReasonIsTimeout: false,
+ name: "invalid args - first is greater than last",
+ params: []processParam{
+ {
+ first: 1,
+ last: 0,
+ more: true,
+ pkt: pk1,
+ },
+ },
+ wantError: true,
+ wantPkt: nil,
},
}
@@ -556,29 +599,31 @@ func TestReleaseCallback(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- result = 0
- callbackReasonIsTimeout = false
+ handler := &testTimeoutHandler{pkt: nil}
- f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler)
- for i, cb := range test.callbacks {
- _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb)
- if err != nil {
+ for _, p := range test.params {
+ if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError {
t.Errorf("f.Process error = %s", err)
}
}
-
- r, ok := f.reassemblers[id]
- if !ok {
- t.Fatalf("Reassemberr not found")
- }
- f.release(r, test.timeout)
-
- if result != test.wantResult {
- t.Errorf("got result = %d, want = %d", result, test.wantResult)
+ if !test.wantError {
+ r, ok := f.reassemblers[id]
+ if !ok {
+ t.Fatal("Reassembler not found")
+ }
+ f.release(r, true)
}
- if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout {
- t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout)
+ switch {
+ case handler.pkt != nil && test.wantPkt == nil:
+ t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView())
+ case handler.pkt == nil && test.wantPkt != nil:
+ t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView())
+ case handler.pkt != nil && test.wantPkt != nil:
+ if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" {
+ t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff)
+ }
}
})
}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index c0cc0bde0..9b20bb1d8 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -15,19 +15,21 @@
package fragmentation
import (
- "container/heap"
- "fmt"
"math"
+ "sort"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
type hole struct {
- first uint16
- last uint16
- deleted bool
+ first uint16
+ last uint16
+ filled bool
+ final bool
+ data buffer.View
}
type reassembler struct {
@@ -37,84 +39,139 @@ type reassembler struct {
proto uint8
mu sync.Mutex
holes []hole
- deleted int
- heap fragHeap
+ filled int
done bool
creationTime int64
- callback func(bool)
+ pkt *stack.PacketBuffer
}
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
r := &reassembler{
id: id,
- holes: make([]hole, 0, 16),
- heap: make(fragHeap, 0, 8),
creationTime: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
- first: 0,
- last: math.MaxUint16,
- deleted: false})
+ first: 0,
+ last: math.MaxUint16,
+ filled: false,
+ final: true,
+ })
return r
}
-// updateHoles updates the list of holes for an incoming fragment and
-// returns true iff the fragment filled at least part of an existing hole.
-func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
- used := false
- for i := range r.holes {
- if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
- continue
- }
- used = true
- r.deleted++
- r.holes[i].deleted = true
- if first > r.holes[i].first {
- r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
- }
- if last < r.holes[i].last && more {
- r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
- }
- }
- return used
-}
-
-func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) {
+func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
- consumed := 0
if r.done {
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, 0, false, consumed, nil
+ return buffer.VectorisedView{}, 0, false, 0, nil
}
- // For IPv6, it is possible to have different Protocol values between
- // fragments of a packet (because, unlike IPv4, the Protocol is not used to
- // identify a fragment). In this case, only the Protocol of the first
- // fragment must be used as per RFC 8200 Section 4.5.
- //
- // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded
- // here (instead of just the protocol) because most IP options should be
- // derived from the first fragment.
- if first == 0 {
- r.proto = proto
- }
- if r.updateHoles(first, last, more) {
- // We store the incoming packet only if it filled some holes.
- heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
- consumed = vv.Size()
+
+ var holeFound bool
+ var consumed int
+ for i := range r.holes {
+ currentHole := &r.holes[i]
+
+ if last < currentHole.first || currentHole.last < first {
+ continue
+ }
+ // For IPv6, overlaps with an existing fragment are explicitly forbidden by
+ // RFC 8200 section 4.5:
+ // If any of the fragments being reassembled overlap with any other
+ // fragments being reassembled for the same packet, reassembly of that
+ // packet must be abandoned and all the fragments that have been received
+ // for that packet must be discarded, and no ICMP error messages should be
+ // sent.
+ //
+ // It is not explicitly forbidden for IPv4, but to keep parity with Linux we
+ // disallow it as well:
+ // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349
+ if first < currentHole.first || currentHole.last < last {
+ // Incoming fragment only partially fits in the free hole.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap
+ }
+ if !more {
+ if !currentHole.final || currentHole.filled && currentHole.last != last {
+ // We have another final fragment, which does not perfectly overlap.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
+ }
+ }
+
+ holeFound = true
+ if currentHole.filled {
+ // Incoming fragment is a duplicate.
+ continue
+ }
+
+ // We are populating the current hole with the payload and creating a new
+ // hole for any unfilled ranges on either end.
+ if first > currentHole.first {
+ r.holes = append(r.holes, hole{
+ first: currentHole.first,
+ last: first - 1,
+ filled: false,
+ final: false,
+ })
+ }
+ if last < currentHole.last && more {
+ r.holes = append(r.holes, hole{
+ first: last + 1,
+ last: currentHole.last,
+ filled: false,
+ final: currentHole.final,
+ })
+ currentHole.final = false
+ }
+ v := pkt.Data.ToOwnedView()
+ consumed = v.Size()
r.size += consumed
+ // Update the current hole to precisely match the incoming fragment.
+ r.holes[i] = hole{
+ first: first,
+ last: last,
+ filled: true,
+ final: currentHole.final,
+ data: v,
+ }
+ r.filled++
+ // For IPv6, it is possible to have different Protocol values between
+ // fragments of a packet (because, unlike IPv4, the Protocol is not used to
+ // identify a fragment). In this case, only the Protocol of the first
+ // fragment must be used as per RFC 8200 Section 4.5.
+ //
+ // TODO(gvisor.dev/issue/3648): During reassembly of an IPv6 packet, IP
+ // options received in the first fragment should be used - and they should
+ // override options from following fragments.
+ if first == 0 {
+ r.pkt = pkt
+ r.proto = proto
+ }
+
+ break
}
- // Check if all the holes have been deleted and we are ready to reassamble.
- if r.deleted < len(r.holes) {
+ if !holeFound {
+ // Incoming fragment is beyond end.
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict
+ }
+
+ // Check if all the holes have been filled and we are ready to reassemble.
+ if r.filled < len(r.holes) {
return buffer.VectorisedView{}, 0, false, consumed, nil
}
- res, err := r.heap.reassemble()
- if err != nil {
- return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err)
+
+ sort.Slice(r.holes, func(i, j int) bool {
+ return r.holes[i].first < r.holes[j].first
+ })
+
+ var size int
+ views := make([]buffer.View, 0, len(r.holes))
+ for _, hole := range r.holes {
+ views = append(views, hole.data)
+ size += hole.data.Size()
}
- return res, r.proto, true, consumed, nil
+ return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil
}
func (r *reassembler) checkDoneOrMark() bool {
@@ -124,24 +181,3 @@ func (r *reassembler) checkDoneOrMark() bool {
r.mu.Unlock()
return prev
}
-
-func (r *reassembler) setCallback(c func(bool)) bool {
- r.mu.Lock()
- defer r.mu.Unlock()
- if r.callback != nil {
- return false
- }
- r.callback = c
- return true
-}
-
-func (r *reassembler) release(timedOut bool) {
- r.mu.Lock()
- callback := r.callback
- r.callback = nil
- r.mu.Unlock()
-
- if callback != nil {
- callback(timedOut)
- }
-}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index fa2a70dc8..2ff03eeeb 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -16,115 +16,175 @@ package fragmentation
import (
"math"
- "reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
-type updateHolesInput struct {
- first uint16
- last uint16
- more bool
+type processParams struct {
+ first uint16
+ last uint16
+ more bool
+ pkt *stack.PacketBuffer
+ wantDone bool
+ wantError error
}
-var holesTestCases = []struct {
- comment string
- in []updateHolesInput
- want []hole
-}{
- {
- comment: "No fragments. Expected holes: {[0 -> inf]}.",
- in: []updateHolesInput{},
- want: []hole{{first: 0, last: math.MaxUint16, deleted: false}},
- },
- {
- comment: "One fragment at beginning. Expected holes: {[2, inf]}.",
- in: []updateHolesInput{{first: 0, last: 1, more: true}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 2, last: math.MaxUint16, deleted: false},
+func TestReassemblerProcess(t *testing.T) {
+ const proto = 99
+
+ v := func(size int) buffer.View {
+ payload := buffer.NewView(size)
+ for i := 1; i < size; i++ {
+ payload[i] = uint8(i) * 3
+ }
+ return payload
+ }
+
+ pkt := func(size int) *stack.PacketBuffer {
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: v(size).ToVectorisedView(),
+ })
+ }
+
+ var tests = []struct {
+ name string
+ params []processParams
+ want []hole
+ }{
+ {
+ name: "No fragments",
+ params: nil,
+ want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}},
},
- },
- {
- comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.",
- in: []updateHolesInput{{first: 1, last: 2, more: true}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 0, last: 0, deleted: false},
- {first: 3, last: math.MaxUint16, deleted: false},
+ {
+ name: "One fragment at beginning",
+ params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: math.MaxUint16, filled: false, final: true},
+ },
},
- },
- {
- comment: "One fragment at the end. Expected holes: {[0, 0]}.",
- in: []updateHolesInput{{first: 1, last: 2, more: false}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 0, last: 0, deleted: false},
+ {
+ name: "One fragment in the middle",
+ params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 1, last: 2, filled: true, final: false, data: v(2)},
+ {first: 0, last: 0, filled: false, final: false},
+ {first: 3, last: math.MaxUint16, filled: false, final: true},
+ },
},
- },
- {
- comment: "One fragment completing a packet. Expected holes: {}.",
- in: []updateHolesInput{{first: 0, last: 1, more: false}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
+ {
+ name: "One fragment at the end",
+ params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}},
+ want: []hole{
+ {first: 1, last: 2, filled: true, final: true, data: v(2)},
+ {first: 0, last: 0, filled: false},
+ },
},
- },
- {
- comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.",
- in: []updateHolesInput{
- {first: 0, last: 1, more: true},
- {first: 2, last: 3, more: false},
+ {
+ name: "One fragment completing a packet",
+ params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}},
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: true, data: v(2)},
+ },
},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 2, last: math.MaxUint16, deleted: true},
+ {
+ name: "Two fragments completing a packet",
+ params: []processParams{
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: 3, filled: true, final: true, data: v(2)},
+ },
},
- },
- {
- comment: "Two overlapping fragments completing a packet. Expected holes: {}.",
- in: []updateHolesInput{
- {first: 0, last: 2, more: true},
- {first: 2, last: 3, more: false},
+ {
+ name: "Two fragments completing a packet with a duplicate",
+ params: []processParams{
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 1, filled: true, final: false, data: v(2)},
+ {first: 2, last: 3, filled: true, final: true, data: v(2)},
+ },
},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 3, last: math.MaxUint16, deleted: true},
+ {
+ name: "Two fragments completing a packet with a partial duplicate",
+ params: []processParams{
+ {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil},
+ {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
+ {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
+ },
+ want: []hole{
+ {first: 0, last: 3, filled: true, final: false, data: v(4)},
+ {first: 4, last: 5, filled: true, final: true, data: v(2)},
+ },
+ },
+ {
+ name: "Two overlapping fragments",
+ params: []processParams{
+ {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil},
+ {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap},
+ },
+ want: []hole{
+ {first: 0, last: 10, filled: true, final: false, data: v(11)},
+ {first: 11, last: math.MaxUint16, filled: false, final: true},
+ },
+ },
+ {
+ name: "Two final fragments with different ends",
+ params: []processParams{
+ {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
+ {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict},
+ },
+ want: []hole{
+ {first: 10, last: 14, filled: true, final: true, data: v(5)},
+ {first: 0, last: 9, filled: false, final: false},
+ },
+ },
+ {
+ name: "Two final fragments - duplicate",
+ params: []processParams{
+ {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
+ {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
+ },
+ want: []hole{
+ {first: 5, last: 14, filled: true, final: true, data: v(10)},
+ {first: 0, last: 4, filled: false, final: false},
+ },
+ },
+ {
+ name: "Two final fragments - duplicate, with different ends",
+ params: []processParams{
+ {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
+ {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict},
+ },
+ want: []hole{
+ {first: 5, last: 14, filled: true, final: true, data: v(10)},
+ {first: 0, last: 4, filled: false, final: false},
+ },
},
- },
-}
-
-func TestUpdateHoles(t *testing.T) {
- for _, c := range holesTestCases {
- r := newReassembler(FragmentID{}, &faketime.NullClock{})
- for _, i := range c.in {
- r.updateHoles(i.first, i.last, i.more)
- }
- if !reflect.DeepEqual(r.holes, c.want) {
- t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want)
- }
}
-}
-func TestSetCallback(t *testing.T) {
- result := 0
- reasonTimeout := false
-
- cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut }
- cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut }
-
- r := newReassembler(FragmentID{}, &faketime.NullClock{})
- if !r.setCallback(cb1) {
- t.Errorf("setCallback failed")
- }
- if r.setCallback(cb2) {
- t.Errorf("setCallback should fail if one is already set")
- }
- r.release(true)
- if result != 1 {
- t.Errorf("got result = %d, want = 1", result)
- }
- if !reasonTimeout {
- t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
+ for _, param := range test.params {
+ _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
+ if done != param.wantDone || err != param.wantError {
+ t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
+ }
+ }
+ if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" {
+ t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD
new file mode 100644
index 000000000..ca1247c1e
--- /dev/null
+++ b/pkg/tcpip/network/ip/BUILD
@@ -0,0 +1,26 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ip",
+ srcs = ["generic_multicast_protocol.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ ],
+)
+
+go_test(
+ name = "ip_test",
+ size = "small",
+ srcs = ["generic_multicast_protocol_test.go"],
+ deps = [
+ ":ip",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/faketime",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
new file mode 100644
index 000000000..f2f0e069c
--- /dev/null
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -0,0 +1,676 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ip holds IPv4/IPv6 common utilities.
+package ip
+
+import (
+ "fmt"
+ "math/rand"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// hostState is the state a host may be in for a multicast group.
+type hostState int
+
+// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1
+// (RFC 2710 section 5). Even though the states are generic across both IGMPv2
+// and MLDv1, IGMPv2 terminology will be used.
+//
+// ______________receive query______________
+// | |
+// | _____send or receive report_____ |
+// | | | |
+// V | V |
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+ |
+// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | -
+// +-------+ +-----------+ +------------+ +-------------------+ +--------+
+// | ^ | ^ | ^ | ^
+// | | | | | | | |
+// ---------- ------- ---------- -------------
+// initialize new send inital fail to send send or receive
+// group membership report delayed report report
+//
+// Not shown in the diagram above, but any state may transition into the non
+// member state when a group is left.
+const (
+ // nonMember is the "'Non-Member' state, when the host does not belong to the
+ // group on the interface. This is the initial state for all memberships on
+ // all network interfaces; it requires no storage in the host."
+ //
+ // 'Non-Listener' is the MLDv1 term used to describe this state.
+ //
+ // This state is used to keep track of groups that have been joined locally,
+ // but without advertising the membership to the network.
+ nonMember hostState = iota
+
+ // pendingMember is a newly joined member that is waiting to successfully send
+ // the initial set of reports.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the initial report needs to be sent.
+ //
+ // MAY NOT transition to the idle member state from this state.
+ pendingMember
+
+ // delayingMember is the "'Delaying Member' state, when the host belongs to
+ // the group on the interface and has a report delay timer running for that
+ // membership."
+ //
+ // 'Delaying Listener' is the MLDv1 term used to describe this state.
+ delayingMember
+
+ // queuedDelayingMember is a delayingMember that failed to send a report after
+ // its delayed report timer fired. Hosts in this state are waiting to attempt
+ // retransmission of the delayed report.
+ //
+ // This is not an RFC defined state; it is an implementation specific state to
+ // track that the delayed report needs to be sent.
+ //
+ // May transition to idle member if a report is received for a group.
+ queuedDelayingMember
+
+ // idleMember is the "Idle Member" state, when the host belongs to the group
+ // on the interface and does not have a report delay timer running for that
+ // membership.
+ //
+ // 'Idle Listener' is the MLDv1 term used to describe this state.
+ idleMember
+)
+
+func (s hostState) isDelayingMember() bool {
+ switch s {
+ case nonMember, pendingMember, idleMember:
+ return false
+ case delayingMember, queuedDelayingMember:
+ return true
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", s))
+ }
+}
+
+// multicastGroupState holds the Generic Multicast Protocol state for a
+// multicast group.
+type multicastGroupState struct {
+ // joins is the number of times the group has been joined.
+ joins uint64
+
+ // state holds the host's state for the group.
+ state hostState
+
+ // lastToSendReport is true if we sent the last report for the group. It is
+ // used to track whether there are other hosts on the subnet that are also
+ // members of the group.
+ //
+ // Defined in RFC 2236 section 6 page 9 for IGMPv2 and RFC 2710 section 5 page
+ // 8 for MLDv1.
+ lastToSendReport bool
+
+ // delayedReportJob is used to delay sending responses to membership report
+ // messages in order to reduce duplicate reports from multiple hosts on the
+ // interface.
+ //
+ // Must not be nil.
+ delayedReportJob *tcpip.Job
+}
+
+// GenericMulticastProtocolOptions holds options for the generic multicast
+// protocol.
+type GenericMulticastProtocolOptions struct {
+ // Rand is the source of random numbers.
+ Rand *rand.Rand
+
+ // Clock is the clock used to create timers.
+ Clock tcpip.Clock
+
+ // Protocol is the implementation of the variant of multicast group protocol
+ // in use.
+ Protocol MulticastGroupProtocol
+
+ // MaxUnsolicitedReportDelay is the maximum amount of time to wait between
+ // transmitting unsolicited reports.
+ //
+ // Unsolicited reports are transmitted when a group is newly joined.
+ MaxUnsolicitedReportDelay time.Duration
+
+ // AllNodesAddress is a multicast address that all nodes on a network should
+ // be a member of.
+ //
+ // This address will not have the generic multicast protocol performed on it;
+ // it will be left in the non member/listener state, and packets will never
+ // be sent for it.
+ AllNodesAddress tcpip.Address
+}
+
+// MulticastGroupProtocol is a multicast group protocol whose core state machine
+// can be represented by GenericMulticastProtocolState.
+type MulticastGroupProtocol interface {
+ // Enabled indicates whether the generic multicast protocol will be
+ // performed.
+ //
+ // When enabled, the protocol may transmit report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // packets.
+ //
+ // When disabled, the protocol will still keep track of locally joined groups,
+ // it just won't transmit and handle packets, or update groups' state.
+ Enabled() bool
+
+ // SendReport sends a multicast report for the specified group address.
+ //
+ // Returns false if the caller should queue the report to be sent later. Note,
+ // returning false does not mean that the receiver hit an error.
+ SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error)
+
+ // SendLeave sends a multicast leave for the specified group address.
+ SendLeave(groupAddress tcpip.Address) *tcpip.Error
+}
+
+// GenericMulticastProtocolState is the per interface generic multicast protocol
+// state.
+//
+// There is actually no protocol named "Generic Multicast Protocol". Instead,
+// the term used to refer to a generic multicast protocol that applies to both
+// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state
+// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710.
+//
+// Callers must synchronize accesses to the generic multicast protocol state;
+// GenericMulticastProtocolState obtains no locks in any of its methods. The
+// only exception to this is GenericMulticastProtocolState's timer/job callbacks
+// which will obtain the lock provided to the GenericMulticastProtocolState when
+// it is initialized.
+//
+// GenericMulticastProtocolState.Init MUST be called before calling any of
+// the methods on GenericMulticastProtocolState.
+//
+// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the
+// multicast group protocol is disabled so that leave messages may be sent.
+type GenericMulticastProtocolState struct {
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
+
+ opts GenericMulticastProtocolOptions
+
+ // memberships holds group addresses and their associated state.
+ memberships map[tcpip.Address]multicastGroupState
+
+ // protocolMU is the mutex used to protect the protocol.
+ protocolMU *sync.RWMutex
+}
+
+// Init initializes the Generic Multicast Protocol state.
+//
+// Must only be called once for the lifetime of g; Init will panic if it is
+// called twice.
+//
+// The GenericMulticastProtocolState will only grab the lock when timers/jobs
+// fire.
+//
+// Note: the methods on opts.Protocol will always be called while protocolMU is
+// held.
+func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) {
+ if g.memberships != nil {
+ panic("attempted to initialize generic membership protocol state twice")
+ }
+
+ *g = GenericMulticastProtocolState{
+ opts: opts,
+ memberships: make(map[tcpip.Address]multicastGroupState),
+ protocolMU: protocolMU,
+ }
+}
+
+// MakeAllNonMemberLocked transitions all groups to the non-member state.
+//
+// The groups will still be considered joined locally.
+//
+// MUST be called when the multicast group protocol is disabled.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ for groupAddress, info := range g.memberships {
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// InitializeGroupsLocked initializes each group, as if they were newly joined
+// but without affecting the groups' join count.
+//
+// Must only be called after calling MakeAllNonMember as a group should not be
+// initialized while it is not in the non-member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) InitializeGroupsLocked() {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ for groupAddress, info := range g.memberships {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// SendQueuedReportsLocked attempts to send reports for groups that failed to
+// send reports during their last attempt.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() {
+ for groupAddress, info := range g.memberships {
+ switch info.state {
+ case nonMember, delayingMember, idleMember:
+ case pendingMember:
+ // pendingMembers failed to send their initial unsolicited report so try
+ // to send the report and queue the extra unsolicited reports.
+ g.maybeSendInitialReportLocked(groupAddress, &info)
+ case queuedDelayingMember:
+ // queuedDelayingMembers failed to send their delayed reports so try to
+ // send the report and transition them to the idle state.
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ default:
+ panic(fmt.Sprintf("unrecognized host state = %d", info.state))
+ }
+ g.memberships[groupAddress] = info
+ }
+}
+
+// JoinGroupLocked handles joining a new group.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) {
+ if info, ok := g.memberships[groupAddress]; ok {
+ // The group has already been joined.
+ info.joins++
+ g.memberships[groupAddress] = info
+ return
+ }
+
+ info := multicastGroupState{
+ // Since we just joined the group, its count is 1.
+ joins: 1,
+ // The state will be updated below, if required.
+ state: nonMember,
+ lastToSendReport: false,
+ delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() {
+ if !g.opts.Protocol.Enabled() {
+ panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress))
+ }
+
+ info, ok := g.memberships[groupAddress]
+ if !ok {
+ panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress))
+ }
+
+ g.maybeSendDelayedReportLocked(groupAddress, &info)
+ g.memberships[groupAddress] = info
+ }),
+ }
+
+ if g.opts.Protocol.Enabled() {
+ g.initializeNewMemberLocked(groupAddress, &info)
+ }
+
+ g.memberships[groupAddress] = info
+}
+
+// IsLocallyJoinedRLocked returns true if the group is locally joined.
+//
+// Precondition: g.protocolMU must be read locked.
+func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool {
+ _, ok := g.memberships[groupAddress]
+ return ok
+}
+
+// LeaveGroupLocked handles leaving the group.
+//
+// Returns false if the group is not currently joined.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool {
+ info, ok := g.memberships[groupAddress]
+ if !ok {
+ return false
+ }
+
+ if info.joins == 0 {
+ panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress))
+ }
+ info.joins--
+ if info.joins != 0 {
+ // If we still have outstanding joins, then do nothing further.
+ g.memberships[groupAddress] = info
+ return true
+ }
+
+ g.transitionToNonMemberLocked(groupAddress, &info)
+ delete(g.memberships, groupAddress)
+ return true
+}
+
+// HandleQueryLocked handles a query message with the specified maximum response
+// time.
+//
+// If the group address is unspecified, then reports will be scheduled for all
+// joined groups.
+//
+// Report(s) will be scheduled to be sent after a random duration between 0 and
+// the maximum response time.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ // As per RFC 2236 section 2.4 (for IGMPv2),
+ //
+ // In a Membership Query message, the group address field is set to zero
+ // when sending a General Query, and set to the group address being
+ // queried when sending a Group-Specific Query.
+ //
+ // As per RFC 2710 section 3.6 (for MLDv1),
+ //
+ // In a Query message, the Multicast Address field is set to zero when
+ // sending a General Query, and set to a specific IPv6 multicast address
+ // when sending a Multicast-Address-Specific Query.
+ if groupAddress.Unspecified() {
+ // This is a general query as the group address is unspecified.
+ for groupAddress, info := range g.memberships {
+ g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
+ g.memberships[groupAddress] = info
+ }
+ } else if info, ok := g.memberships[groupAddress]; ok {
+ g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
+ g.memberships[groupAddress] = info
+ }
+}
+
+// HandleReportLocked handles a report message.
+//
+// If the report is for a joined group, any active delayed report will be
+// cancelled and the host state for the group transitions to idle.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) {
+ if !g.opts.Protocol.Enabled() {
+ return
+ }
+
+ // As per RFC 2236 section 3 pages 3-4 (for IGMPv2),
+ //
+ // If the host receives another host's Report (version 1 or 2) while it has
+ // a timer running, it stops its timer for the specified group and does not
+ // send a Report
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // If a node receives another node's Report from an interface for a
+ // multicast address while it has a timer running for that same address
+ // on that interface, it stops its timer and does not send a Report for
+ // that address, thus suppressing duplicate reports on the link.
+ if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() {
+ info.delayedReportJob.Cancel()
+ info.lastToSendReport = false
+ info.state = idleMember
+ g.memberships[groupAddress] = info
+ }
+}
+
+// initializeNewMemberLocked initializes a new group membership.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != nonMember {
+ panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ info.lastToSendReport = false
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ info.state = idleMember
+ return
+ }
+
+ info.state = pendingMember
+ g.maybeSendInitialReportLocked(groupAddress, info)
+}
+
+// maybeSendInitialReportLocked attempts to start transmission of the initial
+// set of reports after newly joining a group.
+//
+// Host must be in pending member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state != pendingMember {
+ panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a host joins a multicast group, it should immediately transmit an
+ // unsolicited Version 2 Membership Report for that group" ... "it is
+ // recommended that it be repeated".
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // When a node starts listening to a multicast address on an interface,
+ // it should immediately transmit an unsolicited Report for that address
+ // on that interface, in case it is the first listener on the link. To
+ // cover the possibility of the initial Report being lost or damaged, it
+ // is recommended that it be repeated once or twice after short delays
+ // [Unsolicited Report Interval].
+ //
+ // TODO(gvisor.dev/issue/4901): Support a configurable number of initial
+ // unsolicited reports.
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay)
+ }
+}
+
+// maybeSendDelayedReportLocked attempts to send the delayed report.
+//
+// Host must be in pending, delaying or queued delaying member state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if !info.state.isDelayingMember() {
+ panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state))
+ }
+
+ sent, err := g.opts.Protocol.SendReport(groupAddress)
+ if err == nil && sent {
+ info.lastToSendReport = true
+ info.state = idleMember
+ } else {
+ info.state = queuedDelayingMember
+ }
+}
+
+// maybeSendLeave attempts to send a leave message.
+func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) {
+ if !g.opts.Protocol.Enabled() || !lastToSendReport {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // Okay to ignore the error here as if packet write failed, the multicast
+ // routers will eventually drop our membership anyways. If the interface is
+ // being disabled or removed, the generic multicast protocol's should be
+ // cleared eventually.
+ //
+ // As per RFC 2236 section 3 page 5 (for IGMPv2),
+ //
+ // When a router receives a Report, it adds the group being reported to
+ // the list of multicast group memberships on the network on which it
+ // received the Report and sets the timer for the membership to the
+ // [Group Membership Interval]. Repeated Reports refresh the timer. If
+ // no Reports are received for a particular group before this timer has
+ // expired, the router assumes that the group has no local members and
+ // that it need not forward remotely-originated multicasts for that
+ // group onto the attached network.
+ //
+ // As per RFC 2710 section 4 page 5 (for MLDv1),
+ //
+ // When a router receives a Report from a link, if the reported address
+ // is not already present in the router's list of multicast address
+ // having listeners on that link, the reported address is added to the
+ // list, its timer is set to [Multicast Listener Interval], and its
+ // appearance is made known to the router's multicast routing component.
+ // If a Report is received for a multicast address that is already
+ // present in the router's list, the timer for that address is reset to
+ // [Multicast Listener Interval]. If an address's timer expires, it is
+ // assumed that there are no longer any listeners for that address
+ // present on the link, so it is deleted from the list and its
+ // disappearance is made known to the multicast routing component.
+ //
+ // The requirement to send a leave message is also optional (it MAY be
+ // skipped):
+ //
+ // As per RFC 2236 section 6 page 8 (for IGMPv2),
+ //
+ // "send leave" for the group on the interface. If the interface
+ // state says the Querier is running IGMPv1, this action SHOULD be
+ // skipped. If the flag saying we were the last host to report is
+ // cleared, this action MAY be skipped. The Leave Message is sent to
+ // the ALL-ROUTERS group (224.0.0.2).
+ //
+ // As per RFC 2710 section 5 page 8 (for MLDv1),
+ //
+ // "send done" for the address on the interface. If the flag saying
+ // we were the last node to report is cleared, this action MAY be
+ // skipped. The Done message is sent to the link-scope all-routers
+ // address (FF02::2).
+ _ = g.opts.Protocol.SendLeave(groupAddress)
+}
+
+// transitionToNonMemberLocked transitions the given multicast group the the
+// non-member/listener state.
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
+ if info.state == nonMember {
+ return
+ }
+
+ info.delayedReportJob.Cancel()
+ g.maybeSendLeave(groupAddress, info.lastToSendReport)
+ info.lastToSendReport = false
+ info.state = nonMember
+}
+
+// setDelayTimerForAddressRLocked sets timer to send a delay report.
+//
+// Precondition: g.protocolMU MUST be read locked.
+func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) {
+ if info.state == nonMember {
+ return
+ }
+
+ if groupAddress == g.opts.AllNodesAddress {
+ // As per RFC 2236 section 6 page 10 (for IGMPv2),
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ //
+ // As per RFC 2710 section 5 page 10 (for MLDv1),
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ return
+ }
+
+ // As per RFC 2236 section 3 page 3 (for IGMPv2),
+ //
+ // If a timer for the group is already unning, it is reset to the random
+ // value only if the requested Max Response Time is less than the remaining
+ // value of the running timer.
+ //
+ // As per RFC 2710 section 4 page 5 (for MLDv1),
+ //
+ // If a timer for any address is already running, it is reset to the new
+ // random value only if the requested Maximum Response Delay is less than
+ // the remaining value of the running timer.
+ if info.state == delayingMember {
+ // TODO: Reset the timer if time remaining is greater than maxResponseTime.
+ return
+ }
+
+ info.state = delayingMember
+ info.delayedReportJob.Cancel()
+ info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime))
+}
+
+// calculateDelayTimerDuration returns a random time between (0, maxRespTime].
+func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime time.Duration) time.Duration {
+ // As per RFC 2236 section 3 page 3 (for IGMPv2),
+ //
+ // When a host receives a Group-Specific Query, it sets a delay timer to a
+ // random value selected from the range (0, Max Response Time]...
+ //
+ // As per RFC 2710 section 4 page 6 (for MLDv1),
+ //
+ // When a node receives a Multicast-Address-Specific Query, if it is
+ // listening to the queried Multicast Address on the interface from
+ // which the Query was received, it sets a delay timer for that address
+ // to a random value selected from the range [0, Maximum Response Delay],
+ // as above.
+ if maxRespTime == 0 {
+ return 0
+ }
+ return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime)))
+}
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
new file mode 100644
index 000000000..85593f211
--- /dev/null
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -0,0 +1,806 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "math/rand"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+)
+
+const (
+ addr1 = tcpip.Address("\x01")
+ addr2 = tcpip.Address("\x02")
+ addr3 = tcpip.Address("\x03")
+ addr4 = tcpip.Address("\x04")
+
+ maxUnsolicitedReportDelay = time.Second
+)
+
+var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
+
+type mockMulticastGroupProtocolProtectedFields struct {
+ sync.RWMutex
+
+ genericMulticastGroup ip.GenericMulticastProtocolState
+ sendReportGroupAddrCount map[tcpip.Address]int
+ sendLeaveGroupAddrCount map[tcpip.Address]int
+ makeQueuePackets bool
+ disabled bool
+}
+
+type mockMulticastGroupProtocol struct {
+ t *testing.T
+
+ mu mockMulticastGroupProtocolProtectedFields
+}
+
+func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.initLocked()
+ opts.Protocol = m
+ m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts)
+}
+
+func (m *mockMulticastGroupProtocol) initLocked() {
+ m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int)
+ m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
+}
+
+func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.disabled = !v
+}
+
+func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.makeQueuePackets = v
+}
+
+func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.JoinGroupLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.mu.genericMulticastGroup.LeaveGroupLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.HandleReportLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime)
+}
+
+func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr)
+}
+
+func (m *mockMulticastGroupProtocol) makeAllNonMember() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.MakeAllNonMemberLocked()
+}
+
+func (m *mockMulticastGroupProtocol) initializeGroups() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.InitializeGroupsLocked()
+}
+
+func (m *mockMulticastGroupProtocol) sendQueuedReports() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.genericMulticastGroup.SendQueuedReportsLocked()
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be read locked.
+func (m *mockMulticastGroupProtocol) Enabled() bool {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
+ }
+
+ return !m.mu.disabled
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
+func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+
+ m.mu.sendReportGroupAddrCount[groupAddress]++
+ return !m.mu.makeQueuePackets, nil
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: m.mu must be locked.
+func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+
+ m.mu.sendLeaveGroupAddrCount[groupAddress]++
+ return nil
+}
+
+func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ sendReportGroupAddrCount := make(map[tcpip.Address]int)
+ for _, a := range sendReportGroupAddresses {
+ sendReportGroupAddrCount[a] = 1
+ }
+
+ sendLeaveGroupAddrCount := make(map[tcpip.Address]int)
+ for _, a := range sendLeaveGroupAddresses {
+ sendLeaveGroupAddrCount[a] = 1
+ }
+
+ diff := cmp.Diff(
+ &mockMulticastGroupProtocol{
+ mu: mockMulticastGroupProtocolProtectedFields{
+ sendReportGroupAddrCount: sendReportGroupAddrCount,
+ sendLeaveGroupAddrCount: sendLeaveGroupAddrCount,
+ },
+ },
+ m,
+ cmp.AllowUnexported(mockMulticastGroupProtocol{}),
+ cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}),
+ // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
+ cmp.FilterPath(
+ func(p cmp.Path) bool {
+ switch p.Last().String() {
+ case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup":
+ return true
+ }
+ return false
+ },
+ cmp.Ignore(),
+ ),
+ )
+ m.initLocked()
+ return diff
+}
+
+func TestJoinGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendReports bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendReports: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendReports: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(0)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ // Joining a group should send a report immediately and another after
+ // a random interval between 0 and the maximum unsolicited report delay.
+ mgp.joinGroup(test.addr)
+ if test.shouldSendReports {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestLeaveGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ shouldSendMessages bool
+ }{
+ {
+ name: "Normal group",
+ addr: addr1,
+ shouldSendMessages: true,
+ },
+ {
+ name: "All-nodes group",
+ addr: addr2,
+ shouldSendMessages: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(1)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr2,
+ })
+
+ mgp.joinGroup(test.addr)
+ if test.shouldSendMessages {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Leaving a group should send a leave report immediately and cancel any
+ // delayed reports.
+ {
+
+ if !mgp.leaveGroup(test.addr) {
+ t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr)
+ }
+ }
+ if test.shouldSendMessages {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestHandleReport(t *testing.T) {
+ tests := []struct {
+ name string
+ reportAddr tcpip.Address
+ expectReportsFor []tcpip.Address
+ }{
+ {
+ name: "Unpecified empty",
+ reportAddr: "",
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Unpecified any",
+ reportAddr: "\x00",
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified",
+ reportAddr: addr1,
+ expectReportsFor: []tcpip.Address{addr2},
+ },
+ {
+ name: "Specified all-nodes",
+ reportAddr: addr3,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified other",
+ reportAddr: addr4,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(2)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr3)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a report for a group we have a timer scheduled for should
+ // cancel our delayed report timer for the group.
+ mgp.handleReport(test.reportAddr)
+ if len(test.expectReportsFor) != 0 {
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestHandleQuery(t *testing.T) {
+ tests := []struct {
+ name string
+ queryAddr tcpip.Address
+ maxDelay time.Duration
+ expectReportsFor []tcpip.Address
+ }{
+ {
+ name: "Unpecified empty",
+ queryAddr: "",
+ maxDelay: 0,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Unpecified any",
+ queryAddr: "\x00",
+ maxDelay: 1,
+ expectReportsFor: []tcpip.Address{addr1, addr2},
+ },
+ {
+ name: "Specified",
+ queryAddr: addr1,
+ maxDelay: 2,
+ expectReportsFor: []tcpip.Address{addr1},
+ },
+ {
+ name: "Specified all-nodes",
+ queryAddr: addr3,
+ maxDelay: 3,
+ expectReportsFor: nil,
+ },
+ {
+ name: "Specified other",
+ queryAddr: addr4,
+ maxDelay: 4,
+ expectReportsFor: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr3)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a query should make us schedule a new delayed report if it
+ // is a query directed at us or a general query.
+ mgp.handleQuery(test.queryAddr, test.maxDelay)
+ if len(test.expectReportsFor) != 0 {
+ clock.Advance(test.maxDelay)
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestJoinCount(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: time.Second,
+ })
+
+ // Set the join count to 2 for a group.
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ // Only the first join should trigger a report to be sent.
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Group should still be considered joined after leaving once.
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
+ }
+ if !mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ // A leave report should only be sent once the join count reaches 0.
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving once more should actually remove us from the group.
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Group should no longer be joined so we should not have anything to
+ // leave.
+ if mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestMakeAllNonMemberAndInitialize(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ AllNodesAddress: addr3,
+ })
+
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr3)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should send the leave reports for each but still consider them locally
+ // joined.
+ mgp.makeAllNonMember()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ for _, group := range []tcpip.Address{addr1, addr2, addr3} {
+ if !mgp.isLocallyJoined(group) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group)
+ }
+ }
+
+ // Should send the initial set of unsolcited reports.
+ mgp.initializeGroups()
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should have no more messages to send.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+// TestGroupStateNonMember tests that groups do not send packets when in the
+// non-member state, but are still considered locally joined.
+func TestGroupStateNonMember(t *testing.T) {
+ mgp := mockMulticastGroupProtocol{t: t}
+ clock := faketime.NewManualClock()
+
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(3)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+ mgp.setEnabled(false)
+
+ // Joining groups should not send any reports.
+ mgp.joinGroup(addr1)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.joinGroup(addr2)
+ if !mgp.isLocallyJoined(addr1) {
+ t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a query should not send any reports.
+ mgp.handleQuery(addr1, time.Nanosecond)
+ // Generic multicast protocol timers are expected to take the job mutex.
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Leaving groups should not send any leave messages.
+ if !mgp.leaveGroup(addr1) {
+ t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2)
+ }
+ if mgp.isLocallyJoined(addr1) {
+ t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2)
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestQueuedPackets(t *testing.T) {
+ clock := faketime.NewManualClock()
+ mgp := mockMulticastGroupProtocol{t: t}
+ mgp.init(ip.GenericMulticastProtocolOptions{
+ Rand: rand.New(rand.NewSource(4)),
+ Clock: clock,
+ MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
+ })
+
+ // Joining should trigger a SendReport, but mgp should report that we did not
+ // send the packet.
+ mgp.setQueuePackets(true)
+ mgp.joinGroup(addr1)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report timer should have been cancelled since we did not send
+ // the initial report earlier.
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to successfully send the report.
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send (we should be idle).
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query but mock being unable to send reports again.
+ mgp.setQueuePackets(true)
+ mgp.handleQuery(addr1, time.Nanosecond)
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Mock being able to send reports again - we should have a packet queued to
+ // send.
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receive a query again, but mock being unable to send reports.
+ mgp.setQueuePackets(true)
+ mgp.handleQuery(addr1, time.Nanosecond)
+ clock.Advance(time.Nanosecond)
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Receiving a report should should transition us into the idle member state,
+ // even if we had a packet queued. We should no longer have any packets to
+ // send.
+ mgp.handleReport(addr1)
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // When we fail to send the initial set of reports, incoming reports should
+ // not affect a newly joined group's reports from being sent.
+ mgp.setQueuePackets(true)
+ mgp.joinGroup(addr2)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ mgp.handleReport(addr2)
+ // Attempting to send queued reports while still unable to send reports should
+ // not change the host state.
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // Mock being able to successfully send the report.
+ mgp.setQueuePackets(false)
+ mgp.sendQueuedReports()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ // The delayed report (sent after the initial report) should now be sent.
+ clock.Advance(maxUnsolicitedReportDelay)
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anything else to send.
+ mgp.sendQueuedReports()
+ clock.Advance(time.Hour)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index c7d26e14f..3005973d7 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -34,16 +35,16 @@ import (
)
const (
- localIPv4Addr = "\x0a\x00\x00\x01"
- remoteIPv4Addr = "\x0a\x00\x00\x02"
- ipv4SubnetAddr = "\x0a\x00\x00\x00"
- ipv4SubnetMask = "\xff\xff\xff\x00"
- ipv4Gateway = "\x0a\x00\x00\x03"
- localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
- ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01")
+ remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02")
+ ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00")
+ ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00")
+ ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03")
+ localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00")
+ ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00")
+ ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
nicID = 1
)
@@ -192,10 +193,6 @@ func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBu
panic("not implemented")
}
-func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*testObject) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
@@ -206,7 +203,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
panic("not implemented")
}
-func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -222,7 +219,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
-func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
+func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -299,6 +296,10 @@ func (t *testInterface) Enabled() bool {
return !t.mu.disabled
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) setEnabled(v bool) {
t.mu.Lock()
defer t.mu.Unlock()
@@ -343,11 +344,11 @@ func TestSourceAddressValidation(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv6Addr,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv6Addr,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -549,7 +550,7 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -558,59 +559,135 @@ func TestIPv4Send(t *testing.T) {
}
}
-func TestIPv4Receive(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- v4: true,
+func TestReceive(t *testing.T) {
+ tests := []struct {
+ name string
+ protoFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ v4 bool
+ epAddr tcpip.AddressWithPrefix
+ handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface)
+ }{
+ {
+ name: "IPv4",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ v4: true,
+ epAddr: localIPv4Addr.WithPrefix(),
+ handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
+ const totalLen = header.IPv4MinimumSize + 30 /* payload length */
+
+ view := buffer.NewView(totalLen)
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: totalLen,
+ TTL: ipv4.DefaultTTL,
+ Protocol: 10,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: view.ToVectorisedView(),
+ })
+ if ok := parse.IPv4(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(pkt)
+ },
},
- }
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
- defer ep.Close()
+ {
+ name: "IPv6",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ v4: false,
+ epAddr: localIPv6Addr.WithPrefix(),
+ handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
+ const payloadLen = 30
+ view := buffer.NewView(header.IPv6MinimumSize + payloadLen)
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: payloadLen,
+ TransportProtocol: 10,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
+ })
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
+ // Make payload be non-zero.
+ for i := header.IPv6MinimumSize; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
- totalLen := header.IPv4MinimumSize + 30
- view := buffer.NewView(totalLen)
- ip := header.IPv4(view)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- TTL: 20,
- Protocol: 10,
- SrcAddr: remoteIPv4Addr,
- DstAddr: localIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
+ // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen]
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < totalLen; i++ {
- view[i] = uint8(i)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: view.ToVectorisedView(),
+ })
+ if _, _, _, _, ok := parse.IPv6(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(pkt)
+ },
+ },
}
- // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv4Addr
- nic.testObject.dstAddr = localIPv4Addr
- nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
+ })
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: test.v4,
+ },
+ }
+ ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject)
+ defer ep.Close()
- r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: view.ToVectorisedView(),
- })
- if _, _, ok := proto.Parse(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
- r.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
- if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
+ }
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err)
+ } else {
+ ep.DecRef()
+ }
+
+ stat := s.Stats().IP.PacketsReceived
+ if got := stat.Value(); got != 0 {
+ t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got)
+ }
+ test.handlePacket(t, ep, &nic)
+ if nic.testObject.dataCalls != 1 {
+ t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ }
+ if got := stat.Value(); got != 1 {
+ t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ })
}
}
@@ -634,10 +711,6 @@ func TestIPv4ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
}
- r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb")
- if err != nil {
- t.Fatal(err)
- }
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
s := buildDummyStack(t)
@@ -705,8 +778,18 @@ func TestIPv4ReceiveControl(t *testing.T) {
nic.testObject.typ = c.expectedTyp
nic.testObject.extra = c.expectedExtra
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := localIPv4Addr.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
+ }
+
pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize)
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
@@ -716,7 +799,9 @@ func TestIPv4ReceiveControl(t *testing.T) {
}
func TestIPv4FragmentationReceive(t *testing.T) {
- s := buildDummyStack(t)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
testObject: testObject{
@@ -774,11 +859,6 @@ func TestIPv4FragmentationReceive(t *testing.T) {
nic.testObject.dstAddr = localIPv4Addr
nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
- r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
-
// Send first segment.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: frag1.ToVectorisedView(),
@@ -786,7 +866,18 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- r.PopulatePacketInfo(pkt)
+
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := localIPv4Addr.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
+ }
+
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
@@ -799,7 +890,6 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
@@ -843,7 +933,7 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -852,61 +942,6 @@ func TestIPv6Send(t *testing.T) {
}
}
-func TestIPv6Receive(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- },
- }
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- totalLen := header.IPv6MinimumSize + 30
- view := buffer.NewView(totalLen)
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: remoteIPv6Addr,
- DstAddr: localIPv6Addr,
- })
-
- // Make payload be non-zero.
- for i := header.IPv6MinimumSize; i < totalLen; i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv6Addr
- nic.testObject.dstAddr = localIPv6Addr
- nic.testObject.contents = view[header.IPv6MinimumSize:totalLen]
-
- r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: view.ToVectorisedView(),
- })
- if _, _, ok := proto.Parse(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
- r.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
- if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
- }
-}
-
func TestIPv6ReceiveControl(t *testing.T) {
newUint16 := func(v uint16) *uint16 { return &v }
@@ -933,13 +968,6 @@ func TestIPv6ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
}
- r, err := buildIPv6Route(
- localIPv6Addr,
- "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
- )
- if err != nil {
- t.Fatal(err)
- }
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
s := buildDummyStack(t)
@@ -965,11 +993,11 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Create the outer IPv6 header.
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 20,
- SrcAddr: outerSrcAddr,
- DstAddr: localIPv6Addr,
+ PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 20,
+ SrcAddr: outerSrcAddr,
+ DstAddr: localIPv6Addr,
})
// Create the ICMP header.
@@ -979,28 +1007,27 @@ func TestIPv6ReceiveControl(t *testing.T) {
icmp.SetIdent(0xdead)
icmp.SetSequence(0xbeef)
- // Create the inner IPv6 header.
- ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
- ip.Encode(&header.IPv6Fields{
- PayloadLength: 100,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: localIPv6Addr,
- DstAddr: remoteIPv6Addr,
- })
-
+ var extHdrs header.IPv6ExtHdrSerializer
// Build the fragmentation header if needed.
if c.fragmentOffset != nil {
- ip.SetNextHeader(header.IPv6FragmentHeader)
- frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:])
- frag.Encode(&header.IPv6FragmentFields{
- NextHeader: 10,
+ extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{
FragmentOffset: *c.fragmentOffset,
M: true,
Identification: 0x12345678,
})
}
+ // Create the inner IPv6 header.
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: 100,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: localIPv6Addr,
+ DstAddr: remoteIPv6Addr,
+ ExtensionHeaders: extHdrs,
+ })
+
// Make payload be non-zero.
for i := dataOffset; i < len(view); i++ {
view[i] = uint8(i)
@@ -1018,8 +1045,17 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := localIPv6Addr.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
+ }
pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize)
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
@@ -1052,7 +1088,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
dataBuf := [dataLen]byte{1, 2, 3, 4}
data := dataBuf[:]
- ipv4Options := header.IPv4Options{0, 1, 0, 1}
+ ipv4Options := header.IPv4OptionsSerializer{
+ &header.IPv4SerializableListEndOption{},
+ &header.IPv4SerializableNOPOption{},
+ &header.IPv4SerializableListEndOption{},
+ &header.IPv4SerializableNOPOption{},
+ }
+
+ expectOptions := header.IPv4Options{
+ byte(header.IPv4OptionListEndType),
+ byte(header.IPv4OptionNOPType),
+ byte(header.IPv4OptionListEndType),
+ byte(header.IPv4OptionNOPType),
+ }
ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
@@ -1202,7 +1250,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ipHdrLen := header.IPv4MinimumSize + ipv4Options.AllocationSize()
+ ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
totalLen := ipHdrLen + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1225,7 +1273,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
netHdr := pkt.NetworkHeader()
- hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
if len(netHdr.View()) != hdrLen {
t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
}
@@ -1235,7 +1283,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
checker.DstAddr(remoteIPv4Addr),
checker.IPv4HeaderLength(hdrLen),
checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(ipv4Options),
+ checker.IPv4Options(expectOptions),
checker.IPPayload(data),
)
},
@@ -1247,7 +1295,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.AllocationSize()))
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
ip.Encode(&header.IPv4Fields{
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
@@ -1266,7 +1314,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
netHdr := pkt.NetworkHeader()
- hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
if len(netHdr.View()) != hdrLen {
t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
}
@@ -1276,7 +1324,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
checker.DstAddr(remoteIPv4Addr),
checker.IPv4HeaderLength(hdrLen),
checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(ipv4Options),
+ checker.IPv4Options(expectOptions),
checker.IPPayload(data),
)
},
@@ -1295,10 +1343,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return hdr.View().ToVectorisedView()
},
@@ -1338,10 +1386,12 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier),
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ // NB: we're lying about transport protocol here to verify the raw
+ // fragment header bytes.
+ TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return hdr.View().ToVectorisedView()
},
@@ -1373,10 +1423,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return buffer.View(ip).ToVectorisedView()
},
@@ -1408,10 +1458,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- NextHeader: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: header.IPv4Any,
+ TransportProtocol: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 6252614ec..32f53f217 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "ipv4",
srcs = [
"icmp.go",
+ "igmp.go",
"ipv4.go",
],
visibility = ["//visibility:public"],
@@ -17,6 +18,7 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
+ "//pkg/tcpip/network/ip",
"//pkg/tcpip/stack",
],
)
@@ -24,7 +26,10 @@ go_library(
go_test(
name = "ipv4_test",
size = "small",
- srcs = ["ipv4_test.go"],
+ srcs = [
+ "igmp_test.go",
+ "ipv4_test.go",
+ ],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 9b5e37fee..8e392f86c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -63,7 +63,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
stats := e.protocol.stack.Stats()
- received := stats.ICMP.V4PacketsReceived
+ received := stats.ICMP.V4.PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
@@ -90,7 +90,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
iph := header.IPv4(pkt.NetworkHeader().View())
var newOptions header.IPv4Options
- if len(iph) > header.IPv4MinimumSize {
+ if opts := iph.Options(); len(opts) != 0 {
// RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip
// type ICMP packets):
// If a Record Route and/or Time Stamp option is received in an
@@ -106,7 +106,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
} else {
op = &optionUsageReceive{}
}
- aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op)
+ aux, tmp, err := e.processIPOptions(pkt, opts, op)
if err != nil {
switch {
case
@@ -130,7 +130,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4Echo:
received.Echo.Increment()
- sent := stats.ICMP.V4PacketsSent
+ sent := stats.ICMP.V4.PacketsSent
if !e.protocol.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return
@@ -290,6 +290,13 @@ type icmpReasonProtoUnreachable struct{}
func (*icmpReasonProtoUnreachable) isICMPReason() {}
+// icmpReasonTTLExceeded is an error where a packet's time to live exceeded in
+// transit to its final destination, as per RFC 792 page 6, Time Exceeded
+// Message.
+type icmpReasonTTLExceeded struct{}
+
+func (*icmpReasonTTLExceeded) isICMPReason() {}
+
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
// the reception of the first-arriving fragment of that packet.
@@ -342,17 +349,37 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
return nil
}
+ // If we hit a TTL Exceeded error, then we know we are operating as a router.
+ // As per RFC 792 page 6, Time Exceeded Message,
+ //
+ // If the gateway processing a datagram finds the time to live field
+ // is zero it must discard the datagram. The gateway may also notify
+ // the source host via the time exceeded message.
+ //
+ // ...
+ //
+ // Code 0 may be received from a gateway. ...
+ //
+ // Note, Code 0 is the TTL exceeded error.
+ //
+ // If we are operating as a router/gateway, don't use the packet's destination
+ // address as the response's source address as we should not not own the
+ // destination address of a packet we are forwarding.
+ localAddr := origIPHdrDst
+ if _, ok := reason.(*icmpReasonTTLExceeded); ok {
+ localAddr = ""
+ }
// Even if we were able to receive a packet from some remote, we may not have
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
defer route.Release()
- sent := p.stack.Stats().ICMP.V4PacketsSent
+ sent := p.stack.Stats().ICMP.V4.PacketsSent
if !p.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return nil
@@ -454,6 +481,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
counter = sent.DstUnreachable
+ case *icmpReasonTTLExceeded:
+ icmpHdr.SetType(header.ICMPv4TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv4TTLExceeded)
+ counter = sent.TimeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
@@ -483,3 +514,18 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
counter.Increment()
return nil
}
+
+// OnReassemblyTimeout implements fragmentation.TimeoutHandler.
+func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
+ // OnReassemblyTimeout sends a Time Exceeded Message, as per RFC 792:
+ //
+ // If a host reassembling a fragmented datagram cannot complete the
+ // reassembly due to missing fragments within its time limit it discards the
+ // datagram, and it may send a time exceeded message.
+ //
+ // If fragment zero is not available then no time exceeded need be sent at
+ // all.
+ if pkt != nil {
+ p.returnError(&icmpReasonReassemblyTimeout{}, pkt)
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
new file mode 100644
index 000000000..da88d65d1
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -0,0 +1,345 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // igmpV1PresentDefault is the initial state for igmpV1Present in the
+ // igmpState. As per RFC 2236 Page 9 says "No IGMPv1 Router Present ... is
+ // the initial state."
+ igmpV1PresentDefault = 0
+
+ // v1RouterPresentTimeout from RFC 2236 Section 8.11, Page 18
+ // See note on igmpState.igmpV1Present for more detail.
+ v1RouterPresentTimeout = 400 * time.Second
+
+ // v1MaxRespTime from RFC 2236 Section 4, Page 5. "The IGMPv1 router
+ // will send General Queries with the Max Response Time set to 0. This MUST
+ // be interpreted as a value of 100 (10 seconds)."
+ //
+ // Note that the Max Response Time field is a value in units of deciseconds.
+ v1MaxRespTime = 10 * time.Second
+
+ // UnsolicitedReportIntervalMax is the maximum delay between sending
+ // unsolicited IGMP reports.
+ //
+ // Obtained from RFC 2236 Section 8.10, Page 19.
+ UnsolicitedReportIntervalMax = 10 * time.Second
+)
+
+// IGMPOptions holds options for IGMP.
+type IGMPOptions struct {
+ // Enabled indicates whether IGMP will be performed.
+ //
+ // When enabled, IGMP may transmit IGMP report and leave messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // IGMP packets.
+ //
+ // This field is ignored and is always assumed to be false for interfaces
+ // without neighbouring nodes (e.g. loopback).
+ Enabled bool
+}
+
+var _ ip.MulticastGroupProtocol = (*igmpState)(nil)
+
+// igmpState is the per-interface IGMP state.
+//
+// igmpState.init() MUST be called after creating an IGMP state.
+type igmpState struct {
+ // The IPv4 endpoint this igmpState is for.
+ ep *endpoint
+
+ genericMulticastProtocol ip.GenericMulticastProtocolState
+
+ // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from
+ // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1
+ // Membership Reports in response to its Queries, and will not pay
+ // attention to Version 2 Membership Reports. Therefore, a state variable
+ // MUST be kept for each interface, describing whether the multicast
+ // Querier on that interface is running IGMPv1 or IGMPv2. This variable
+ // MUST be based upon whether or not an IGMPv1 query was heard in the last
+ // [Version 1 Router Present Timeout] seconds".
+ //
+ // Must be accessed with atomic operations. Holds a value of 1 when true, 0
+ // when false.
+ igmpV1Present uint32
+
+ // igmpV1Job is scheduled when this interface receives an IGMPv1 style
+ // message, upon expiration the igmpV1Present flag is cleared.
+ // igmpV1Job may not be nil once igmpState is initialized.
+ igmpV1Job *tcpip.Job
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+func (igmp *igmpState) Enabled() bool {
+ // No need to perform IGMP on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return igmp.ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() && igmp.ep.Enabled()
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+ igmpType := header.IGMPv2MembershipReport
+ if igmp.v1Present() {
+ igmpType = header.IGMPv1MembershipReport
+ }
+ return igmp.writePacket(groupAddress, groupAddress, igmpType)
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ // As per RFC 2236 Section 6, Page 8: "If the interface state says the
+ // Querier is running IGMPv1, this action SHOULD be skipped. If the flag
+ // saying we were the last host to report is cleared, this action MAY be
+ // skipped."
+ if igmp.v1Present() {
+ return nil
+ }
+ _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup)
+ return err
+}
+
+// init sets up an igmpState struct, and is required to be called before using
+// a new igmpState.
+//
+// Must only be called once for the lifetime of igmp.
+func (igmp *igmpState) init(ep *endpoint) {
+ igmp.ep = ep
+ igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
+ Rand: ep.protocol.stack.Rand(),
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: igmp,
+ MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
+ AllNodesAddress: header.IPv4AllSystems,
+ })
+ igmp.igmpV1Present = igmpV1PresentDefault
+ igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() {
+ igmp.setV1Present(false)
+ })
+}
+
+// handleIGMP handles an IGMP packet.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
+ stats := igmp.ep.protocol.stack.Stats()
+ received := stats.IGMP.PacketsReceived
+ headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+ h := header.IGMP(headerView)
+
+ // Temporarily reset the checksum field to 0 in order to calculate the proper
+ // checksum.
+ wantChecksum := h.Checksum()
+ h.SetChecksum(0)
+ gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
+ h.SetChecksum(wantChecksum)
+
+ if gotChecksum != wantChecksum {
+ received.ChecksumErrors.Increment()
+ return
+ }
+
+ switch h.Type() {
+ case header.IGMPMembershipQuery:
+ received.MembershipQuery.Increment()
+ if len(headerView) < header.IGMPQueryMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime())
+ case header.IGMPv1MembershipReport:
+ received.V1MembershipReport.Increment()
+ if len(headerView) < header.IGMPReportMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ igmp.handleMembershipReport(h.GroupAddress())
+ case header.IGMPv2MembershipReport:
+ received.V2MembershipReport.Increment()
+ if len(headerView) < header.IGMPReportMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+ igmp.handleMembershipReport(h.GroupAddress())
+ case header.IGMPLeaveGroup:
+ received.LeaveGroup.Increment()
+ // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or
+ // Report, are ignored in all states"
+
+ default:
+ // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should
+ // be silently ignored. New message types may be used by newer versions of
+ // IGMP, by multicast routing protocols, or other uses."
+ received.Unrecognized.Increment()
+ }
+}
+
+func (igmp *igmpState) v1Present() bool {
+ return atomic.LoadUint32(&igmp.igmpV1Present) == 1
+}
+
+func (igmp *igmpState) setV1Present(v bool) {
+ if v {
+ atomic.StoreUint32(&igmp.igmpV1Present, 1)
+ } else {
+ atomic.StoreUint32(&igmp.igmpV1Present, 0)
+ }
+}
+
+// handleMembershipQuery handles a membership query.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) {
+ // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero
+ // then change the state to note that an IGMPv1 router is present and
+ // schedule the query received Job.
+ if maxRespTime == 0 && igmp.Enabled() {
+ igmp.igmpV1Job.Cancel()
+ igmp.igmpV1Job.Schedule(v1RouterPresentTimeout)
+ igmp.setV1Present(true)
+ maxRespTime = v1MaxRespTime
+ }
+
+ igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime)
+}
+
+// handleMembershipReport handles a membership report.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
+ igmp.genericMulticastProtocol.HandleReportLocked(groupAddress)
+}
+
+// writePacket assembles and sends an IGMP packet.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) {
+ igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize))
+ igmpData.SetType(igmpType)
+ igmpData.SetGroupAddress(groupAddress)
+ igmpData.SetChecksum(header.IGMPCalculateChecksum(igmpData))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()),
+ Data: buffer.View(igmpData).ToVectorisedView(),
+ })
+
+ addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */)
+ if addressEndpoint == nil {
+ return false, nil
+ }
+ localAddr := addressEndpoint.AddressWithPrefix().Address
+ addressEndpoint.DecRef()
+ addressEndpoint = nil
+ igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.IGMPProtocolNumber,
+ TTL: header.IGMPTTL,
+ TOS: stack.DefaultTOS,
+ }, header.IPv4OptionsSerializer{
+ &header.IPv4SerializableRouterAlertOption{},
+ })
+
+ sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
+ if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ sentStats.Dropped.Increment()
+ return false, err
+ }
+ switch igmpType {
+ case header.IGMPv1MembershipReport:
+ sentStats.V1MembershipReport.Increment()
+ case header.IGMPv2MembershipReport:
+ sentStats.V2MembershipReport.Increment()
+ case header.IGMPLeaveGroup:
+ sentStats.LeaveGroup.Increment()
+ default:
+ panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType))
+ }
+ return true, nil
+}
+
+// joinGroup handles adding a new group to the membership map, setting up the
+// IGMP state for the group, and sending and scheduling the required
+// messages.
+//
+// If the group already exists in the membership map, returns
+// tcpip.ErrDuplicateAddress.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
+ igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress)
+}
+
+// isInGroup returns true if the specified group has been joined locally.
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool {
+ return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
+}
+
+// leaveGroup handles removing the group from the membership map, cancels any
+// delay timers associated with that group, and sends the Leave Group message
+// if required.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+ // LeaveGroup returns false only if the group was not joined.
+ if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
+ return nil
+ }
+
+ return tcpip.ErrBadLocalAddress
+}
+
+// softLeaveAll leaves all groups from the perspective of IGMP, but remains
+// joined locally.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) softLeaveAll() {
+ igmp.genericMulticastProtocol.MakeAllNonMemberLocked()
+}
+
+// initializeAll attemps to initialize the IGMP state for each group that has
+// been joined locally.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) initializeAll() {
+ igmp.genericMulticastProtocol.InitializeGroupsLocked()
+}
+
+// sendQueuedReports attempts to send any reports that are queued for sending.
+//
+// Precondition: igmp.ep.mu must be locked.
+func (igmp *igmpState) sendQueuedReports() {
+ igmp.genericMulticastProtocol.SendQueuedReportsLocked()
+}
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
new file mode 100644
index 000000000..1ee573ac8
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -0,0 +1,215 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ addr = tcpip.Address("\x0a\x00\x00\x01")
+ multicastAddr = tcpip.Address("\xe0\x00\x00\x03")
+ nicID = 1
+)
+
+// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet
+// sent to the provided address with the passed fields set. Raises a t.Error if
+// any field does not match.
+func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv4(t, payload,
+ checker.SrcAddr(addr),
+ checker.DstAddr(remoteAddress),
+ // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
+ checker.TTL(1),
+ checker.IPv4RouterAlert(),
+ checker.IGMP(
+ checker.IGMPType(igmpType),
+ checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
+ checker.IGMPGroupAddress(groupAddress),
+ ),
+ )
+}
+
+func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ // Create an endpoint of queue size 1, since no more than 1 packets are ever
+ // queued in the tests in this file.
+ e := channel.New(1, 1280, linkAddr)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{
+ IGMP: ipv4.IGMPOptions{
+ Enabled: igmpEnabled,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ return e, s, clock
+}
+
+func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) {
+ buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize)
+
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(len(buf)),
+ TTL: 1,
+ Protocol: uint8(header.IGMPProtocolNumber),
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv4AllSystems,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ igmp := header.IGMP(buf[header.IPv4MinimumSize:])
+ igmp.SetType(igmpType)
+ igmp.SetMaxRespTime(maxRespTime)
+ igmp.SetGroupAddress(groupAddress)
+ igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
+
+ e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is
+// present on the network. The IGMP stack will then send IGMPv1 Membership
+// reports for backwards compatibility.
+func TestIgmpV1Present(t *testing.T) {
+ e, s, clock := createStack(t, true)
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ }
+
+ if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
+ }
+
+ // This NIC will send an IGMPv2 report immediately, before this test can get
+ // the IGMPv1 General Membership Query in.
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
+ }
+ if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
+ t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
+ }
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Inject an IGMPv1 General Membership Query which is identical to a standard
+ // membership query except the Max Response Time is set to 0, which will tell
+ // the stack that this is a router using IGMPv1. Send it to the all systems
+ // group which is the only group this host belongs to.
+ createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, header.IPv4AllSystems)
+ if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 {
+ t.Fatalf("got Membership Queries received = %d, want = 1", got)
+ }
+
+ // Before advancing the clock, verify that this host has not sent a
+ // V1MembershipReport yet.
+ if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 {
+ t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got)
+ }
+
+ // Verify the solicited Membership Report is sent. Now that this NIC has seen
+ // an IGMPv1 query, it should send an IGMPv1 Membership Report.
+ p, ok = e.Read()
+ if ok {
+ t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt)
+ }
+ clock.Advance(ipv4.UnsolicitedReportIntervalMax)
+ p, ok = e.Read()
+ if !ok {
+ t.Fatal("unable to Read IGMP packet, expected V1MembershipReport")
+ }
+ if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 {
+ t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got)
+ }
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr)
+}
+
+func TestSendQueuedIGMPReports(t *testing.T) {
+ e, s, clock := createStack(t, true)
+
+ // Joining a group without an assigned address should queue IGMP packets; none
+ // should be sent without an assigned address.
+ if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err)
+ }
+ reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport
+ if got := reportStat.Value(); got != 0 {
+ t.Errorf("got reportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got unexpected packet = %#v", p)
+ }
+
+ // The initial set of IGMP reports that were queued should be sent once an
+ // address is assigned.
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ }
+ if got := reportStat.Value(); got != 1 {
+ t.Errorf("got reportStat.Value() = %d, want = 1", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Error("expected to send an IGMP membership report")
+ } else {
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ clock.Advance(ipv4.UnsolicitedReportIntervalMax)
+ if got := reportStat.Value(); got != 2 {
+ t.Errorf("got reportStat.Value() = %d, want = 2", got)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Error("expected to send an IGMP membership report")
+ } else {
+ validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Should have no more packets to send after the initial set of unsolicited
+ // reports.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got unexpected packet = %#v", p)
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index a376cb8ec..e9ff70d04 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -83,6 +83,7 @@ type endpoint struct {
sync.RWMutex
addressableEndpointState stack.AddressableEndpointState
+ igmp igmpState
}
}
@@ -93,7 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa
dispatcher: dispatcher,
protocol: p,
}
+ e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
+ e.mu.igmp.init(e)
+ e.mu.Unlock()
return e
}
@@ -121,11 +125,22 @@ func (e *endpoint) Enable() *tcpip.Error {
// We have no need for the address endpoint.
ep.DecRef()
+ // Groups may have been joined while the endpoint was disabled, or the
+ // endpoint may have left groups from the perspective of IGMP when the
+ // endpoint was disabled. Either way, we need to let routers know to
+ // send us multicast traffic.
+ e.mu.igmp.initializeAll()
+
// As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
// multicast group. Note, the IANA calls the all-hosts multicast group the
// all-systems multicast group.
- _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems)
- return err
+ if err := e.joinGroupLocked(header.IPv4AllSystems); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv4 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllSystems, err))
+ }
+
+ return nil
}
// Enabled implements stack.NetworkEndpoint.
@@ -157,19 +172,27 @@ func (e *endpoint) Disable() {
}
func (e *endpoint) disableLocked() {
- if !e.setEnabled(false) {
+ if !e.isEnabled() {
return
}
// The endpoint may have already left the multicast group.
- if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
}
+ // Leave groups from the perspective of IGMP so that routers know that
+ // we are no longer interested in the group.
+ e.mu.igmp.softLeaveAll()
+
// The address may have already been removed.
if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
}
+
+ if !e.setEnabled(false) {
+ panic("should have only done work to disable the endpoint if it was enabled")
+ }
}
// DefaultTTL is the default time-to-live value for this endpoint.
@@ -198,37 +221,34 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) {
hdrLen := header.IPv4MinimumSize
- var opts header.IPv4Options
- if params.Options != nil {
- var ok bool
- if opts, ok = params.Options.(header.IPv4Options); !ok {
- panic(fmt.Sprintf("want IPv4Options, got %T", params.Options))
- }
- hdrLen += opts.AllocationSize()
- if hdrLen > header.IPv4MaximumHeaderSize {
- // Since we have no way to report an error we must either panic or create
- // a packet which is different to what was requested. Choose panic as this
- // would be a programming error that should be caught in testing.
- panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.AllocationSize(), header.IPv4MaximumOptionsSize))
- }
+ var optLen int
+ if options != nil {
+ optLen = int(options.Length())
+ }
+ hdrLen += optLen
+ if hdrLen > header.IPv4MaximumHeaderSize {
+ // Since we have no way to report an error we must either panic or create
+ // a packet which is different to what was requested. Choose panic as this
+ // would be a programming error that should be caught in testing.
+ panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize))
}
ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
length := uint16(pkt.Size())
// RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
// datagrams. Since the DF bit is never being set here, all datagrams
// are non-atomic and need an ID.
- id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
TotalLength: length,
ID: uint16(id),
TTL: params.TTL,
TOS: params.TOS,
Protocol: uint8(params.Protocol),
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- Options: opts,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ Options: options,
})
ip.SetChecksum(^ip.CalculateChecksum())
pkt.NetworkProtocolNumber = ProtocolNumber
@@ -259,17 +279,14 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r, pkt, params)
- return e.writePacket(r, gso, pkt)
-}
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesOutputDropped.Increment()
+ e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
@@ -286,24 +303,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
if err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
// Since we rewrote the packet but it is being routed back to us, we can
// safely assume the checksum is valid.
pkt.RXTransportChecksumValidated = true
- ep.HandlePacket(pkt)
+ ep.(*endpoint).handlePacket(pkt)
}
return nil
}
}
+ return e.writePacket(r, gso, pkt, false /* headerIncluded */)
+}
+
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- loopedR := r.MakeLoopedRoute()
- loopedR.PopulatePacketInfo(pkt)
- loopedR.Release()
- e.HandlePacket(pkt)
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = !headerIncluded
+ e.handlePacket(pkt)
}
}
if r.Loop&stack.PacketOut == 0 {
@@ -347,7 +367,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.addIPHeader(r, pkt, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */)
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
@@ -374,8 +394,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.protocol.stack.IPTables()
- dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
@@ -400,9 +419,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.(*endpoint).handlePacket(pkt)
}
n++
continue
@@ -461,7 +481,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// non-atomic datagrams, so assign an ID to all such datagrams
// according to the definition given in RFC 6864 section 4.
if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
- ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
+ ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress, r.RemoteAddress, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
}
@@ -479,16 +499,85 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return tcpip.ErrMalformedHeader
}
- return e.writePacket(r, nil /* gso */, pkt)
+ return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */)
+}
+
+// forwardPacket attempts to forward a packet to its final destination.
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+ h := header.IPv4(pkt.NetworkHeader().View())
+ ttl := h.TTL()
+ if ttl == 0 {
+ // As per RFC 792 page 6, Time Exceeded Message,
+ //
+ // If the gateway processing a datagram finds the time to live field
+ // is zero it must discard the datagram. The gateway may also notify
+ // the source host via the time exceeded message.
+ return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt)
+ }
+
+ dstAddr := h.DestinationAddress()
+
+ // Check if the destination is owned by the stack.
+ networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
+ if err == nil {
+ networkEndpoint.(*endpoint).handlePacket(pkt)
+ return nil
+ }
+ if err != tcpip.ErrBadAddress {
+ return err
+ }
+
+ r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ // We need to do a deep copy of the IP packet because
+ // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
+ // not own it.
+ newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader()))
+
+ // As per RFC 791 page 30, Time to Live,
+ //
+ // This field must be decreased at each point that the internet header
+ // is processed to reflect the time spent processing the datagram.
+ // Even if no local information is available on the time actually
+ // spent, the field must be decremented by 1.
+ newHdr.SetTTL(ttl - 1)
+
+ return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(newHdr).ToVectorisedView(),
+ }))
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats()
+ stats.IP.PacketsReceived.Increment()
+
if !e.isEnabled() {
+ stats.IP.DisabledPacketsReceived.Increment()
return
}
+ // Loopback traffic skips the prerouting chain.
+ if !e.nic.IsLoopback() {
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ // iptables is telling us to drop the packet.
+ stats.IP.IPTablesPreroutingDropped.Increment()
+ return
+ }
+ }
+
+ e.handlePacket(pkt)
+}
+
+// handlePacket is like HandlePacket except it does not perform the prerouting
+// iptables hook.
+func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.protocol.stack.Stats()
@@ -524,19 +613,46 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
+
// As per RFC 1122 section 3.2.1.3:
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
// multicast address).
- if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) {
+ if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
+ // Make sure the source address is not a subnet-local broadcast address.
+ if addressEndpoint := e.AcquireAssignedAddress(srcAddr, false /* createTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
+ subnet := addressEndpoint.Subnet()
+ addressEndpoint.DecRef()
+ if subnet.IsBroadcast(srcAddr) {
+ stats.IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+ }
+
+ // The destination address should be an address we own or a group we joined
+ // for us to receive the packet. Otherwise, attempt to forward the packet.
+ if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
+ subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ addressEndpoint.DecRef()
+ pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
+ } else if !e.IsInGroup(dstAddr) {
+ if !e.protocol.Forwarding() {
+ stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+
+ _ = e.forwardPacket(pkt)
+ return
+ }
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
stats.IP.IPTablesInputDropped.Increment()
return
@@ -565,29 +681,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- // Set up a callback in case we need to send a Time Exceeded Message, as per
- // RFC 792:
- //
- // If a host reassembling a fragmented datagram cannot complete the
- // reassembly due to missing fragments within its time limit it discards
- // the datagram, and it may send a time exceeded message.
- //
- // If fragment zero is not available then no time exceeded need be sent at
- // all.
- var releaseCB func(bool)
- if start == 0 {
- pkt := pkt.Clone()
- releaseCB = func(timedOut bool) {
- if timedOut {
- _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
- }
- }
- }
-
- var ready bool
- var err error
proto := h.Protocol()
- pkt.Data, _, ready, err = e.protocol.fragmentation.Process(
+ data, _, ready, err := e.protocol.fragmentation.Process(
// As per RFC 791 section 2.3, the identification value is unique
// for a source-destination pair and protocol.
fragmentation.FragmentID{
@@ -600,8 +695,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
start+uint16(pkt.Data.Size())-1,
h.More(),
proto,
- pkt.Data,
- releaseCB,
+ pkt,
)
if err != nil {
stats.IP.MalformedPacketsReceived.Increment()
@@ -611,6 +705,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !ready {
return
}
+ pkt.Data = data
// The reassembler doesn't take care of fixing up the header, so we need
// to do it here.
@@ -628,11 +723,17 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.handleICMP(pkt)
return
}
- if len(h.Options()) != 0 {
+ if p == header.IGMPProtocolNumber {
+ e.mu.Lock()
+ e.mu.igmp.handleIGMP(pkt)
+ e.mu.Unlock()
+ return
+ }
+ if opts := h.Options(); len(opts) != 0 {
// TODO(gvisor.dev/issue/4586):
// When we add forwarding support we should use the verified options
// rather than just throwing them away.
- aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{})
+ aux, _, err := e.processIPOptions(pkt, opts, &optionUsageReceive{})
if err != nil {
switch {
case
@@ -683,7 +784,12 @@ func (e *endpoint) Close() {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ if err == nil {
+ e.mu.igmp.sendQueuedReports()
+ }
+ return ep, err
}
// RemovePermanentAddress implements stack.AddressableEndpoint.
@@ -706,34 +812,26 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo
defer e.mu.Unlock()
loopback := e.nic.IsLoopback()
- addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
+ return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool {
subnet := addressEndpoint.Subnet()
// IPv4 has a notion of a subnet broadcast address and considers the
// loopback interface bound to an address's whole subnet (on linux).
return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
- })
- if addressEndpoint != nil {
- return addressEndpoint
- }
-
- if !allowTemp {
- return nil
- }
-
- addr := localAddr.WithPrefix()
- addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB)
- if err != nil {
- // AddAddress only returns an error if the address is already assigned,
- // but we just checked above if the address exists so we expect no error.
- panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err))
- }
- return addressEndpoint
+ }, allowTemp, tempPEB)
}
// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
e.mu.RLock()
defer e.mu.RUnlock()
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
+}
+
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements
+//
+// Precondition: igmp.ep.mu must be read locked.
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
}
@@ -752,32 +850,48 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.joinGroupLocked(addr)
+}
+
+// joinGroupLocked is like JoinGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
if !header.IsV4MulticastAddress(addr) {
- return false, tcpip.ErrBadAddress
+ return tcpip.ErrBadAddress
}
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.mu.addressableEndpointState.JoinGroup(addr)
+ e.mu.igmp.joinGroup(addr)
+ return nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.LeaveGroup(addr)
+ return e.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked is like LeaveGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+ return e.mu.igmp.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.IsInGroup(addr)
+ return e.mu.igmp.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
+var _ fragmentation.TimeoutHandler = (*protocol)(nil)
type protocol struct {
stack *stack.Stack
@@ -798,6 +912,8 @@ type protocol struct {
hashIV uint32
fragmentation *fragmentation.Fragmentation
+
+ options Options
}
// Number returns the ipv4 protocol number.
@@ -922,17 +1038,23 @@ func addressToUint32(addr tcpip.Address) uint32 {
return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24
}
-// hashRoute calculates a hash value for the given route. It uses the source &
-// destination address, the transport protocol number and a 32-bit number to
-// generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
- a := addressToUint32(r.LocalAddress)
- b := addressToUint32(r.RemoteAddress)
+// hashRoute calculates a hash value for the given source/destination pair using
+// the addresses, transport protocol number and a 32-bit number to generate the
+// hash.
+func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
+ a := addressToUint32(srcAddr)
+ b := addressToUint32(dstAddr)
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
-// NewProtocol returns an IPv4 network protocol.
-func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+// Options holds options to configure a new protocol.
+type Options struct {
+ // IGMP holds options for IGMP.
+ IGMP IGMPOptions
+}
+
+// NewProtocolWithOptions returns an IPv4 network protocol.
+func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
@@ -942,15 +1064,24 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
}
hashIV := r[buckets]
- return &protocol{
- stack: s,
- ids: ids,
- hashIV: hashIV,
- defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
+ return func(s *stack.Stack) stack.NetworkProtocol {
+ p := &protocol{
+ stack: s,
+ ids: ids,
+ hashIV: hashIV,
+ defaultTTL: DefaultTTL,
+ options: opts,
+ }
+ p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
+ return p
}
}
+// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options.
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return NewProtocolWithOptions(Options{})(s)
+}
+
func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) {
fragPkt, offset, copied, more := pf.BuildNextFragment()
fragPkt.NetworkProtocolNumber = ProtocolNumber
@@ -1063,6 +1194,12 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres
}
pointer := tsOpt.Pointer()
+ // RFC 791 page 22 states: "The smallest legal value is 5."
+ // Since the pointer is 1 based, and the header is 4 bytes long the
+ // pointer must point beyond the header therefore 4 or less is bad.
+ if pointer <= header.IPv4OptionTimestampHdrLength {
+ return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer
+ }
// To simplify processing below, base further work on the array of timestamps
// beyond the header, rather than on the whole option. Also to aid
// calculations set 'nextSlot' to be 0 based as in the packet it is 1 based.
@@ -1149,7 +1286,15 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
}
- nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based.
+ pointer := rrOpt.Pointer()
+ // RFC 791 page 20 states:
+ // The pointer is relative to this option, and the
+ // smallest legal value for the pointer is 4.
+ // Since the pointer is 1 based, and the header is 3 bytes long the
+ // pointer must point beyond the header therefore 3 or less is bad.
+ if pointer <= header.IPv4OptionRecordRouteHdrLength {
+ return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer
+ }
// RFC 791 page 21 says
// If the route data area is already full (the pointer exceeds the
@@ -1164,14 +1309,14 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// do this (as do most implementations). It is probable that the inclusion
// of these words is a copy/paste error from the timestamp option where
// there are two failure reasons given.
- if nextSlot >= optlen {
+ if pointer > optlen {
return 0, nil
}
// The data area isn't full but there isn't room for a new entry.
// Either Length or Pointer could be bad. We must select Pointer for Linux
- // compatibility, even if only the length is bad.
- if nextSlot+header.IPv4AddressSize > optlen {
+ // compatibility, even if only the length is bad. NB. pointer is 1 based.
+ if pointer+header.IPv4AddressSize > optlen+1 {
if false {
// This is what we would do if we were not being Linux compatible.
// Check for bad pointer or length value. Must be a multiple of 4 after
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index c6e565455..1c4919b1e 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -15,9 +15,11 @@
package ipv4_test
import (
+ "bytes"
"context"
"encoding/hex"
"fmt"
+ "io/ioutil"
"math"
"net"
"testing"
@@ -103,6 +105,163 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
+func TestForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ randomSequence = 123
+ randomIdent = 42
+ )
+
+ ipv4Addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
+ PrefixLen: 8,
+ }
+ ipv4Addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
+ PrefixLen: 8,
+ }
+ remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4())
+ remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4())
+
+ tests := []struct {
+ name string
+ TTL uint8
+ expectErrorICMP bool
+ }{
+ {
+ name: "TTL of zero",
+ TTL: 0,
+ expectErrorICMP: true,
+ },
+ {
+ name: "TTL of one",
+ TTL: 1,
+ expectErrorICMP: false,
+ },
+ {
+ name: "TTL of two",
+ TTL: 2,
+ expectErrorICMP: false,
+ },
+ {
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ expectErrorICMP: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e1 := channel.New(1, ipv4.MaxTotalSize, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1}
+ if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err)
+ }
+
+ e2 := channel.New(1, ipv4.MaxTotalSize, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2}
+ if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: ipv4Addr1.Subnet(),
+ NIC: nicID1,
+ },
+ {
+ Destination: ipv4Addr2.Subnet(),
+ NIC: nicID2,
+ },
+ })
+
+ if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err)
+ }
+
+ totalLen := uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(totalLen))
+ icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv4Echo)
+ icmp.SetCode(header.ICMPv4UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(^header.Checksum(icmp, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: totalLen,
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: test.TTL,
+ SrcAddr: remoteIPv4Addr1,
+ DstAddr: remoteIPv4Addr2,
+ })
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+
+ if test.expectErrorICMP {
+ reply, ok := e1.Read()
+ if !ok {
+ t.Fatal("expected ICMP TTL Exceeded packet through incoming NIC")
+ }
+
+ checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(ipv4Addr1.Address),
+ checker.DstAddr(remoteIPv4Addr1),
+ checker.TTL(ipv4.DefaultTTL),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4TimeExceeded),
+ checker.ICMPv4Code(header.ICMPv4TTLExceeded),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
+
+ if n := e2.Drain(); n != 0 {
+ t.Fatalf("got e2.Drain() = %d, want = 0", n)
+ }
+ } else {
+ reply, ok := e2.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo packet through outgoing NIC")
+ }
+
+ checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(remoteIPv4Addr1),
+ checker.DstAddr(remoteIPv4Addr2),
+ checker.TTL(test.TTL-1),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Payload(nil),
+ ),
+ )
+
+ if n := e1.Drain(); n != 0 {
+ t.Fatalf("got e1.Drain() = %d, want = 0", n)
+ }
+ }
+ })
+ }
+}
+
// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
// checks the response.
func TestIPv4Sanity(t *testing.T) {
@@ -319,7 +478,7 @@ func TestIPv4Sanity(t *testing.T) {
68, 7, 5, 0,
// ^ ^ Linux points here which is wrong.
// | Not a multiple of 4
- 1, 2, 3,
+ 1, 2, 3, 0,
},
shouldFail: true,
expectErrorICMP: true,
@@ -398,6 +557,56 @@ func TestIPv4Sanity(t *testing.T) {
},
},
{
+ // Timestamp pointer uses one based counting so 0 is invalid.
+ name: "timestamp pointer invalid",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 8, 0, 0x00,
+ // ^ 0 instead of 5 or more.
+ 0, 0, 0, 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ // Timestamp pointer cannot be less than 5. It must point past the header
+ // which is 4 bytes. (1 based counting)
+ name: "timestamp pointer too small by 1",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 8, header.IPv4OptionTimestampHdrLength, 0x00,
+ // ^ header is 4 bytes, so 4 should fail.
+ 0, 0, 0, 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ name: "valid timestamp pointer",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00,
+ // ^ header is 4 bytes, so 5 should succeed.
+ 0, 0, 0, 0,
+ },
+ replyOptions: header.IPv4Options{
+ 68, 8, 9, 0x00,
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
// Needs 8 bytes for a type 1 timestamp but there are only 4 free.
name: "bad timer element alignment",
maxTotalLength: ipv4.MaxTotalSize,
@@ -528,7 +737,61 @@ func TestIPv4Sanity(t *testing.T) {
},
},
{
- // Confirm linux bug for bug compatibility.
+ // Pointer uses one based counting so 0 is invalid.
+ name: "record route pointer zero",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 7, 8, 0, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ // Pointer must be 4 or more as it must point past the 3 byte header
+ // using 1 based counting. 3 should fail.
+ name: "record route pointer too small by 1",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ // Pointer must be 4 or more as it must point past the 3 byte header
+ // using 1 based counting. Check 4 passes. (Duplicates "single
+ // record route with room")
+ name: "valid record route pointer",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ replyOptions: header.IPv4Options{
+ 7, 7, 8, // 3 byte header
+ 192, 168, 1, 58, // New IP Address.
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ // Confirm Linux bug for bug compatibility.
// Linux returns slot 22 but the error is in slot 21.
name: "multiple record route with not enough room",
maxTotalLength: ipv4.MaxTotalSize,
@@ -599,9 +862,12 @@ func TestIPv4Sanity(t *testing.T) {
},
})
- ipHeaderLength := header.IPv4MinimumSize + test.options.AllocationSize()
+ if len(test.options)%4 != 0 {
+ t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options))
+ }
+ ipHeaderLength := header.IPv4MinimumSize + len(test.options)
if ipHeaderLength > header.IPv4MaximumHeaderSize {
- t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
}
totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
hdr := buffer.NewPrependable(int(totalLen))
@@ -618,16 +884,26 @@ func TestIPv4Sanity(t *testing.T) {
if test.maxTotalLength < totalLen {
totalLen = test.maxTotalLength
}
+
ip.Encode(&header.IPv4Fields{
TotalLength: totalLen,
Protocol: test.transportProtocol,
TTL: test.TTL,
SrcAddr: remoteIPv4Addr,
DstAddr: ipv4Addr.Address,
- Options: test.options,
})
if test.headerLength != 0 {
ip.SetHeaderLength(test.headerLength)
+ } else {
+ // Set the calculated header length, since we may manually add options.
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ }
+ if len(test.options) != 0 {
+ // Copy options manually. We do not use Encode for options so we can
+ // verify malformed options with handcrafted payloads.
+ if want, got := copy(ip.Options(), test.options), len(test.options); want != got {
+ t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want)
+ }
}
ip.SetChecksum(0)
ipHeaderChecksum := ip.CalculateChecksum()
@@ -2049,6 +2325,28 @@ func TestReceiveFragments(t *testing.T) {
},
expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
},
+ {
+ name: "Two fragments with MF flag reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload4Addr1ToAddr2[:65512],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 65512,
+ payload: ipv4Payload4Addr1ToAddr2[65512:],
+ },
+ },
+ expectedPayloads: nil,
+ },
}
for _, test := range tests {
@@ -2112,18 +2410,26 @@ func TestReceiveFragments(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
}
+ const rcvSize = 65536 // Account for reassembled packets.
for i, expectedPayload := range test.expectedPayloads {
- gotPayload, _, err := ep.Read(nil)
+ var buf bytes.Buffer
+ result, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{})
if err != nil {
- t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ t.Fatalf("(i=%d) Read: %s", i, err)
}
- if diff := cmp.Diff(buffer.View(expectedPayload), gotPayload); diff != "" {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: len(expectedPayload),
+ Total: len(expectedPayload),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff)
+ }
+ if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" {
t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
}
}
- if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
}
})
}
@@ -2242,7 +2548,7 @@ func TestWriteStats(t *testing.T) {
test.setup(t, rt.Stack())
- nWritten, _ := writer.writePackets(&rt, pkts)
+ nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
@@ -2259,7 +2565,7 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
})
@@ -2441,9 +2747,6 @@ func TestPacketQueing(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 0ac24a6fb..afa45aefe 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -8,6 +8,7 @@ go_library(
"dhcpv6configurationfromndpra_string.go",
"icmp.go",
"ipv6.go",
+ "mld.go",
"ndp.go",
],
visibility = ["//visibility:public"],
@@ -19,6 +20,7 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
+ "//pkg/tcpip/network/ip",
"//pkg/tcpip/stack",
],
)
@@ -49,3 +51,19 @@ go_test(
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
+
+go_test(
+ name = "ipv6_x_test",
+ size = "small",
+ srcs = ["mld_test.go"],
+ deps = [
+ ":ipv6",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 8502b848c..6ee162713 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -126,8 +126,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
stats := e.protocol.stack.Stats().ICMP
- sent := stats.V6PacketsSent
- received := stats.V6PacketsReceived
+ sent := stats.V6.PacketsSent
+ received := stats.V6.PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
@@ -163,7 +163,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
// TODO(b/112892170): Meaningfully handle all ICMP types.
- switch h.Type() {
+ switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
received.PacketTooBig.Increment()
hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
@@ -358,7 +358,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize))
packet.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(packet.NDPPayload())
+ na := header.NDPNeighborAdvert(packet.MessageBody())
// As per RFC 4861 section 7.2.4:
//
@@ -644,8 +644,39 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
return
}
+ case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone:
+ switch icmpType {
+ case header.ICMPv6MulticastListenerQuery:
+ received.MulticastListenerQuery.Increment()
+ case header.ICMPv6MulticastListenerReport:
+ received.MulticastListenerReport.Increment()
+ case header.ICMPv6MulticastListenerDone:
+ received.MulticastListenerDone.Increment()
+ default:
+ panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
+ }
+
+ if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ switch icmpType {
+ case header.ICMPv6MulticastListenerQuery:
+ e.mu.Lock()
+ e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView()))
+ e.mu.Unlock()
+ case header.ICMPv6MulticastListenerReport:
+ e.mu.Lock()
+ e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView()))
+ e.mu.Unlock()
+ case header.ICMPv6MulticastListenerDone:
+ default:
+ panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
+ }
+
default:
- received.Invalid.Increment()
+ received.Unrecognized.Increment()
}
}
@@ -681,12 +712,12 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
packet.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(packet.NDPPayload())
+ ns := header.NDPNeighborSolicit(packet.MessageBody())
ns.SetTargetAddress(targetAddr)
ns.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- stat := p.stack.Stats().ICMP.V6PacketsSent
+ stat := p.stack.Stats().ICMP.V6.PacketsSent
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
@@ -750,6 +781,12 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+// icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in
+// transit to its final destination, as per RFC 4443 section 3.3.
+type icmpReasonHopLimitExceeded struct{}
+
+func (*icmpReasonHopLimitExceeded) isICMPReason() {}
+
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
// the reception of the first-arriving fragment of that packet.
@@ -790,29 +827,49 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
allowResponseToMulticast = reason.respondToMulticast
}
- if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any {
+ isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst)
+ if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any {
return nil
}
+ // If we hit a Hop Limit Exceeded error, then we know we are operating as a
+ // router. As per RFC 4443 section 3.3:
+ //
+ // If a router receives a packet with a Hop Limit of zero, or if a
+ // router decrements a packet's Hop Limit to zero, it MUST discard the
+ // packet and originate an ICMPv6 Time Exceeded message with Code 0 to
+ // the source of the packet. This indicates either a routing loop or
+ // too small an initial Hop Limit value.
+ //
+ // If we are operating as a router, do not use the packet's destination
+ // address as the response's source address as we should not own the
+ // destination address of a packet we are forwarding.
+ //
+ // If the packet was originally destined to a multicast address, then do not
+ // use the packet's destination address as the source for the response ICMP
+ // packet as "multicast addresses must not be used as source addresses in IPv6
+ // packets", as per RFC 4291 section 2.7.
+ localAddr := origIPHdrDst
+ if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast {
+ localAddr = ""
+ }
// Even if we were able to receive a packet from some remote, we may not have
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
defer route.Release()
stats := p.stack.Stats().ICMP
- sent := stats.V6PacketsSent
+ sent := stats.V6.PacketsSent
if !p.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return nil
}
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
-
if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
// TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
// Unfortunately at this time ICMP Packets do not have a transport
@@ -830,6 +887,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
}
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+
// As per RFC 4443 section 2.4
//
// (c) Every ICMPv6 error message (type < 128) MUST include
@@ -873,6 +932,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6PortUnreachable)
counter = sent.DstUnreachable
+ case *icmpReasonHopLimitExceeded:
+ icmpHdr.SetType(header.ICMPv6TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
+ counter = sent.TimeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv6TimeExceeded)
icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
@@ -896,3 +959,16 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
counter.Increment()
return nil
}
+
+// OnReassemblyTimeout implements fragmentation.TimeoutHandler.
+func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
+ // OnReassemblyTimeout sends a Time Exceeded Message as per RFC 2460 Section
+ // 4.5:
+ //
+ // If the first fragment (i.e., the one with a Fragment Offset of zero) has
+ // been received, an ICMP Time Exceeded -- Fragment Reassembly Time Exceeded
+ // message should be sent to the source of that fragment.
+ if pkt != nil {
+ p.returnError(&icmpReasonReassemblyTimeout{}, pkt)
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 76013daa1..bbce1ef78 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -144,11 +144,14 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) Promiscuous() bool {
+ return false
+}
+
func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- r := stack.Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
- }
+ var r stack.Route
+ r.NetProto = protocol
+ r.ResolveWith(remoteLinkAddr)
return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
}
@@ -174,13 +177,8 @@ func TestICMPCounts(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: test.useNeighborCache,
})
- {
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -206,11 +204,16 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := lladdr0.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
var tllData [header.NDPLinkLayerAddressSize]byte
header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
@@ -267,6 +270,22 @@ func TestICMPCounts(t *testing.T) {
typ: header.ICMPv6RedirectMsg,
size: header.ICMPv6MinimumSize,
},
+ {
+ typ: header.ICMPv6MulticastListenerQuery,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerReport,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerDone,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: 255, /* Unrecognized */
+ size: 50,
+ },
}
handleIPv6Payload := func(icmp header.ICMPv6) {
@@ -276,13 +295,12 @@ func TestICMPCounts(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -290,7 +308,7 @@ func TestICMPCounts(t *testing.T) {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
handleIPv6Payload(icmp)
}
@@ -298,7 +316,7 @@ func TestICMPCounts(t *testing.T) {
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
- icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
if got, want := s.Value(), uint64(1); got != want {
t.Errorf("got %s = %d, want = %d", name, got, want)
@@ -317,13 +335,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
UseNeighborCache: true,
})
- {
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -349,11 +362,16 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := lladdr0.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
var tllData [header.NDPLinkLayerAddressSize]byte
header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
@@ -410,6 +428,22 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
typ: header.ICMPv6RedirectMsg,
size: header.ICMPv6MinimumSize,
},
+ {
+ typ: header.ICMPv6MulticastListenerQuery,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerReport,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: header.ICMPv6MulticastListenerDone,
+ size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
+ },
+ {
+ typ: 255, /* Unrecognized */
+ size: 50,
+ },
}
handleIPv6Payload := func(icmp header.ICMPv6) {
@@ -419,13 +453,12 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -433,7 +466,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
handleIPv6Payload(icmp)
}
@@ -441,7 +474,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
- icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
if got, want := s.Value(), uint64(1); got != want {
t.Errorf("got %s = %d, want = %d", name, got, want)
@@ -566,7 +599,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
return
}
- if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress {
+ if len(args.remoteLinkAddr) != 0 && pi.Route.RemoteLinkAddress != args.remoteLinkAddr {
t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
}
@@ -819,11 +852,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
}
ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
@@ -831,7 +864,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
typStat := typ.statCounter(stats)
@@ -896,11 +929,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
errorICMPBody := func(view buffer.View) {
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
+ PayloadLength: simpleBodySize,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
})
simpleBody(view[header.IPv6MinimumSize:])
}
@@ -1014,11 +1047,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(icmpSize),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(icmpSize),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1026,7 +1059,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
typStat := typ.statCounter(stats)
@@ -1074,11 +1107,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
errorICMPBody := func(view buffer.View) {
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
+ PayloadLength: simpleBodySize,
+ TransportProtocol: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
})
simpleBody(view[header.IPv6MinimumSize:])
}
@@ -1193,11 +1226,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(size + payloadSize),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(size + payloadSize),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
@@ -1205,7 +1238,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
e.InjectInbound(ProtocolNumber, pkt)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
typStat := typ.statCounter(stats)
@@ -1411,11 +1444,11 @@ func TestPacketQueing(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1453,11 +1486,11 @@ func TestPacketQueing(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1502,7 +1535,7 @@ func TestPacketQueing(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
NIC: nicID,
},
@@ -1541,7 +1574,7 @@ func TestPacketQueing(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
@@ -1552,11 +1585,11 @@ func TestPacketQueing(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -1590,7 +1623,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
return icmp
},
@@ -1610,7 +1643,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
ns.Options().Serialize(header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(linkAddr1),
@@ -1627,7 +1660,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
return icmp
},
@@ -1643,7 +1676,7 @@ func TestCallsToNeighborCache(t *testing.T) {
nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(nsSize))
icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(lladdr0)
ns.Options().Serialize(header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(linkAddr1),
@@ -1660,7 +1693,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1681,7 +1714,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1700,7 +1733,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(false)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1720,7 +1753,7 @@ func TestCallsToNeighborCache(t *testing.T) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
icmp := header.ICMPv6(buffer.NewView(naSize))
icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
na.SetSolicitedFlag(false)
na.SetOverrideFlag(false)
na.SetTargetAddress(lladdr1)
@@ -1775,30 +1808,31 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := lladdr0.WithPrefix()
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ ep.DecRef()
}
- defer r.Release()
-
- // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch.
- r.LocalAddress = test.destination
icmp := test.createPacket()
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.IPv6MinimumSize,
Data: buffer.View(icmp).ToVectorisedView(),
})
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: r.RemoteAddress,
- DstAddr: r.LocalAddress,
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.source,
+ DstAddr: test.destination,
})
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
// Confirm the endpoint calls the correct NUDHandler method.
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 0526190cc..f2018d073 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -19,6 +19,7 @@ import (
"encoding/binary"
"fmt"
"hash/fnv"
+ "math"
"sort"
"sync/atomic"
"time"
@@ -34,7 +35,9 @@ import (
)
const (
+ // ReassembleTimeout controls how long a fragment will be held.
// As per RFC 8200 section 4.5:
+ //
// If insufficient fragments are received to complete reassembly of a packet
// within 60 seconds of the reception of the first-arriving fragment of that
// packet, reassembly of that packet must be abandoned.
@@ -58,6 +61,108 @@ const (
buckets = 2048
)
+// policyTable is the default policy table defined in RFC 6724 section 2.1.
+//
+// A more human-readable version:
+//
+// Prefix Precedence Label
+// ::1/128 50 0
+// ::/0 40 1
+// ::ffff:0:0/96 35 4
+// 2002::/16 30 2
+// 2001::/32 5 5
+// fc00::/7 3 13
+// ::/96 1 3
+// fec0::/10 1 11
+// 3ffe::/16 1 12
+//
+// The table is sorted by prefix length so longest-prefix match can be easily
+// achieved.
+//
+// We willingly left out ::/96, fec0::/10 and 3ffe::/16 since those prefix
+// assignments are deprecated.
+//
+// As per RFC 4291 section 2.5.5.1 (for ::/96),
+//
+// The "IPv4-Compatible IPv6 address" is now deprecated because the
+// current IPv6 transition mechanisms no longer use these addresses.
+// New or updated implementations are not required to support this
+// address type.
+//
+// As per RFC 3879 section 4 (for fec0::/10),
+//
+// This document formally deprecates the IPv6 site-local unicast prefix
+// defined in [RFC3513], i.e., 1111111011 binary or FEC0::/10.
+//
+// As per RFC 3701 section 1 (for 3ffe::/16),
+//
+// As clearly stated in [TEST-NEW], the addresses for the 6bone are
+// temporary and will be reclaimed in the future. It further states
+// that all users of these addresses (within the 3FFE::/16 prefix) will
+// be required to renumber at some time in the future.
+//
+// and section 2,
+//
+// Thus after the pTLA allocation cutoff date January 1, 2004, it is
+// REQUIRED that no new 6bone 3FFE pTLAs be allocated.
+//
+// MUST NOT BE MODIFIED.
+var policyTable = [...]struct {
+ subnet tcpip.Subnet
+
+ label uint8
+}{
+ // ::1/128
+ {
+ subnet: header.IPv6Loopback.WithPrefix().Subnet(),
+ label: 0,
+ },
+ // ::ffff:0:0/96
+ {
+ subnet: header.IPv4MappedIPv6Subnet,
+ label: 4,
+ },
+ // 2001::/32 (Teredo prefix as per RFC 4380 section 2.6).
+ {
+ subnet: tcpip.AddressWithPrefix{
+ Address: "\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 32,
+ }.Subnet(),
+ label: 5,
+ },
+ // 2002::/16 (6to4 prefix as per RFC 3056 section 2).
+ {
+ subnet: tcpip.AddressWithPrefix{
+ Address: "\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 16,
+ }.Subnet(),
+ label: 2,
+ },
+ // fc00::/7 (Unique local addresses as per RFC 4193 section 3.1).
+ {
+ subnet: tcpip.AddressWithPrefix{
+ Address: "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 7,
+ }.Subnet(),
+ label: 13,
+ },
+ // ::/0
+ {
+ subnet: header.IPv6EmptySubnet,
+ label: 1,
+ },
+}
+
+func getLabel(addr tcpip.Address) uint8 {
+ for _, p := range policyTable {
+ if p.subnet.Contains(addr) {
+ return p.label
+ }
+ }
+
+ panic(fmt.Sprintf("should have a label for address = %s", addr))
+}
+
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -83,6 +188,7 @@ type endpoint struct {
addressableEndpointState stack.AddressableEndpointState
ndp ndpState
+ mld mldState
}
}
@@ -118,6 +224,45 @@ type OpaqueInterfaceIdentifierOptions struct {
SecretKey []byte
}
+// onAddressAssignedLocked handles an address being assigned.
+//
+// Precondition: e.mu must be exclusively locked.
+func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
+ // As per RFC 2710 section 3,
+ //
+ // All MLD messages described in this document are sent with a link-local
+ // IPv6 Source Address, ...
+ //
+ // If we just completed DAD for a link-local address, then attempt to send any
+ // queued MLD reports. Note, we may have sent reports already for some of the
+ // groups before we had a valid link-local address to use as the source for
+ // the MLD messages, but that was only so that MLD snooping switches are aware
+ // of our membership to groups - routers would not have handled those reports.
+ //
+ // As per RFC 3590 section 4,
+ //
+ // MLD Report and Done messages are sent with a link-local address as
+ // the IPv6 source address, if a valid address is available on the
+ // interface. If a valid link-local address is not available (e.g., one
+ // has not been configured), the message is sent with the unspecified
+ // address (::) as the IPv6 source address.
+ //
+ // Once a valid link-local address is available, a node SHOULD generate
+ // new MLD Report messages for all multicast addresses joined on the
+ // interface.
+ //
+ // Routers receiving an MLD Report or Done message with the unspecified
+ // address as the IPv6 source address MUST silently discard the packet
+ // without taking any action on the packets contents.
+ //
+ // Snooping switches MUST manage multicast forwarding state based on MLD
+ // Report and Done messages sent with the unspecified address as the
+ // IPv6 source address.
+ if header.IsV6LinkLocalAddress(addr) {
+ e.mu.mld.sendQueuedReports()
+ }
+}
+
// InvalidateDefaultRouter implements stack.NDPEndpoint.
func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.mu.Lock()
@@ -224,6 +369,12 @@ func (e *endpoint) Enable() *tcpip.Error {
return nil
}
+ // Groups may have been joined when the endpoint was disabled, or the
+ // endpoint may have left groups from the perspective of MLD when the
+ // endpoint was disabled. Either way, we need to let routers know to
+ // send us multicast traffic.
+ e.mu.mld.initializeAll()
+
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
// use IPv6. This is required to ensure that this node properly receives
// and responds to the various NDP messages that are destined to the
@@ -241,8 +392,10 @@ func (e *endpoint) Enable() *tcpip.Error {
// (NDP NS) messages may be sent to the All-Nodes multicast group if the
// source address of the NDP NS is the unspecified address, as per RFC 4861
// section 7.2.4.
- if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil {
- return err
+ if err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv6AllNodesMulticastAddress, err))
}
// Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
@@ -251,7 +404,7 @@ func (e *endpoint) Enable() *tcpip.Error {
// Addresses may have aleady completed DAD but in the time since the endpoint
// was last enabled, other devices may have acquired the same addresses.
var err *tcpip.Error
- e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
addr := addressEndpoint.AddressWithPrefix().Address
if !header.IsV6UnicastAddress(addr) {
return true
@@ -273,7 +426,7 @@ func (e *endpoint) Enable() *tcpip.Error {
}
// Do not auto-generate an IPv6 link-local address for loopback devices.
- if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() {
+ if e.protocol.options.AutoGenLinkLocal && !e.nic.IsLoopback() {
// The valid and preferred lifetime is infinite for the auto-generated
// link-local address.
e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
@@ -322,7 +475,7 @@ func (e *endpoint) Disable() {
}
func (e *endpoint) disableLocked() {
- if !e.setEnabled(false) {
+ if !e.Enabled() {
return
}
@@ -331,9 +484,17 @@ func (e *endpoint) disableLocked() {
e.stopDADForPermanentAddressesLocked()
// The endpoint may have already left the multicast group.
- if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
}
+
+ // Leave groups from the perspective of MLD so that routers know that
+ // we are no longer interested in the group.
+ e.mu.mld.softLeaveAll()
+
+ if !e.setEnabled(false) {
+ panic("should have only done work to disable the endpoint if it was enabled")
+ }
}
// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
@@ -341,7 +502,7 @@ func (e *endpoint) disableLocked() {
// Precondition: e.mu must be write locked.
func (e *endpoint) stopDADForPermanentAddressesLocked() {
// Stop DAD for all the tentative unicast addresses.
- e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
if addressEndpoint.GetKind() != stack.PermanentTentative {
return true
}
@@ -373,19 +534,27 @@ func (e *endpoint) MTU() uint32 {
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
+ // TODO(gvisor.dev/issues/5035): The maximum header length returned here does
+ // not open the possibility for the caller to know about size required for
+ // extension headers.
return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
- length := uint16(pkt.Size())
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) {
+ extHdrsLen := extensionHeaders.Length()
+ length := pkt.Size() + extensionHeaders.Length()
+ if length > math.MaxUint16 {
+ panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16))
+ }
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
- PayloadLength: length,
- NextHeader: uint8(params.Protocol),
- HopLimit: params.TTL,
- TrafficClass: params.TOS,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ PayloadLength: uint16(length),
+ TransportProtocol: params.Protocol,
+ HopLimit: params.TTL,
+ TrafficClass: params.TOS,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ ExtensionHeaders: extensionHeaders,
})
pkt.NetworkProtocolNumber = ProtocolNumber
}
@@ -440,18 +609,14 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- e.addIPHeader(r, pkt, params)
- return e.writePacket(r, gso, pkt, params.Protocol)
-}
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */)
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesOutputDropped.Increment()
+ e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
@@ -467,24 +632,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
// Since we rewrote the packet but it is being routed back to us, we can
// safely assume the checksum is valid.
pkt.RXTransportChecksumValidated = true
- ep.HandlePacket(pkt)
+ ep.(*endpoint).handlePacket(pkt)
}
return nil
}
}
+ return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */)
+}
+
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- loopedR := r.MakeLoopedRoute()
- loopedR.PopulatePacketInfo(pkt)
- loopedR.Release()
- e.HandlePacket(pkt)
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = !headerIncluded
+ e.handlePacket(pkt)
}
}
if r.Loop&stack.PacketOut == 0 {
@@ -530,7 +698,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- e.addIPHeader(r, pb, params)
+ e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */)
networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
if err != nil {
@@ -558,8 +726,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
@@ -584,9 +751,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- route.PopulatePacketInfo(pkt)
- ep.HandlePacket(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.(*endpoint).handlePacket(pkt)
}
n++
continue
@@ -640,16 +808,85 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return tcpip.ErrMalformedHeader
}
- return e.writePacket(r, nil /* gso */, pkt, proto)
+ return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */)
+}
+
+// forwardPacket attempts to forward a packet to its final destination.
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+ h := header.IPv6(pkt.NetworkHeader().View())
+ hopLimit := h.HopLimit()
+ if hopLimit <= 1 {
+ // As per RFC 4443 section 3.3,
+ //
+ // If a router receives a packet with a Hop Limit of zero, or if a
+ // router decrements a packet's Hop Limit to zero, it MUST discard the
+ // packet and originate an ICMPv6 Time Exceeded message with Code 0 to
+ // the source of the packet. This indicates either a routing loop or
+ // too small an initial Hop Limit value.
+ return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt)
+ }
+
+ dstAddr := h.DestinationAddress()
+
+ // Check if the destination is owned by the stack.
+ networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
+ if err == nil {
+ networkEndpoint.(*endpoint).handlePacket(pkt)
+ return nil
+ }
+ if err != tcpip.ErrBadAddress {
+ return err
+ }
+
+ r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ // We need to do a deep copy of the IP packet because
+ // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
+ // not own it.
+ newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader()))
+
+ // As per RFC 8200 section 3,
+ //
+ // Hop Limit 8-bit unsigned integer. Decremented by 1 by
+ // each node that forwards the packet.
+ newHdr.SetHopLimit(hopLimit - 1)
+
+ return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(newHdr).ToVectorisedView(),
+ }))
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats()
+ stats.IP.PacketsReceived.Increment()
+
if !e.isEnabled() {
+ stats.IP.DisabledPacketsReceived.Increment()
return
}
+ // Loopback traffic skips the prerouting chain.
+ if !e.nic.IsLoopback() {
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ // iptables is telling us to drop the packet.
+ stats.IP.IPTablesPreroutingDropped.Increment()
+ return
+ }
+ }
+
+ e.handlePacket(pkt)
+}
+
+// handlePacket is like HandlePacket except it does not perform the prerouting
+// iptables hook.
+func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.protocol.stack.Stats()
@@ -669,6 +906,20 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
+ // The destination address should be an address we own or a group we joined
+ // for us to receive the packet. Otherwise, attempt to forward the packet.
+ if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ } else if !e.IsInGroup(dstAddr) {
+ if !e.protocol.Forwarding() {
+ stats.IP.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+
+ _ = e.forwardPacket(pkt)
+ return
+ }
+
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
@@ -681,8 +932,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
stats.IP.IPTablesInputDropped.Increment()
return
@@ -888,18 +1138,6 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- // Set up a callback in case we need to send a Time Exceeded Message as
- // per RFC 2460 Section 4.5.
- var releaseCB func(bool)
- if start == 0 {
- pkt := pkt.Clone()
- releaseCB = func(timedOut bool) {
- if timedOut {
- _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
- }
- }
- }
-
// Note that pkt doesn't have its transport header set after reassembly,
// and won't until DeliverNetworkPacket sets it.
data, proto, ready, err := e.protocol.fragmentation.Process(
@@ -914,17 +1152,17 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
start+uint16(fragmentPayloadLen)-1,
extHdr.More(),
uint8(rawPayload.Identifier),
- rawPayload.Buf,
- releaseCB,
+ pkt,
)
if err != nil {
stats.IP.MalformedPacketsReceived.Increment()
stats.IP.MalformedFragmentsReceived.Increment()
return
}
- pkt.Data = data
if ready {
+ pkt.Data = data
+
// We create a new iterator with the reassembled packet because we could
// have more extension headers in the reassembled payload, as per RFC
// 8200 section 4.5. We also use the NextHeader value from the first
@@ -1023,9 +1261,16 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
//
// Which when taken together indicate that an unknown protocol should
// be treated as an unrecognized next header value.
+ // The location of the Next Header field is in a different place in
+ // the initial IPv6 header than it is in the extension headers so
+ // treat it specially.
+ prevHdrIDOffset := uint32(header.IPv6NextHeaderOffset)
+ if previousHeaderStart != 0 {
+ prevHdrIDOffset = previousHeaderStart
+ }
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
- pointer: it.ParseOffset(),
+ pointer: prevHdrIDOffset,
}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
@@ -1033,12 +1278,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
default:
- _ = e.protocol.returnError(&icmpReasonParameterProblem{
- code: header.ICMPv6UnknownHeader,
- pointer: it.ParseOffset(),
- }, pkt)
- stats.UnknownProtocolRcvdPackets.Increment()
- return
+ // Since the iterator returns IPv6RawPayloadHeader for unknown Extension
+ // Header IDs this should never happen unless we missed a supported type
+ // here.
+ panic(fmt.Sprintf("unrecognized type from it.Next() = %T", extHdr))
+
}
}
}
@@ -1086,11 +1330,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
return addressEndpoint, nil
}
- snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil {
- return nil, err
- }
-
addressEndpoint.SetKind(stack.PermanentTentative)
if e.Enabled() {
@@ -1099,6 +1338,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
}
}
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if err := e.joinGroupLocked(snmc); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a valid
+ // IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err))
+ }
+
return addressEndpoint, nil
}
@@ -1144,7 +1390,8 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn
}
snmc := header.SolicitedNodeAddr(addr.Address)
- if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
+ // The endpoint may have already left the multicast group.
+ if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
@@ -1167,7 +1414,7 @@ func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool {
//
// Precondition: e.mu must be read or write locked.
func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint {
- return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr)
+ return e.mu.addressableEndpointState.GetAddress(localAddr)
}
// MainAddress implements stack.AddressableEndpoint.
@@ -1199,6 +1446,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow
return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
}
+// getLinkLocalAddressRLocked returns a link-local address from the primary list
+// of addresses, if one is available.
+//
+// See stack.PrimaryEndpointBehavior for more details about the primary list.
+//
+// Precondition: e.mu must be read locked.
+func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address {
+ var linkLocalAddr tcpip.Address
+ e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
+ if addressEndpoint.IsAssigned(false /* allowExpired */) {
+ if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) {
+ linkLocalAddr = addr
+ return false
+ }
+ }
+ return true
+ })
+ return linkLocalAddr
+}
+
// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
// but with locking requirements.
//
@@ -1208,7 +1475,11 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
// RFC 6724 section 5.
type addrCandidate struct {
addressEndpoint stack.AddressEndpoint
+ addr tcpip.Address
scope header.IPv6AddressScope
+
+ label uint8
+ matchingPrefix uint8
}
if len(remoteAddr) == 0 {
@@ -1218,10 +1489,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
// Create a candidate set of available addresses we can potentially use as a
// source address.
var cs []addrCandidate
- e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
+ e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
// If r is not valid for outgoing connections, it is not a valid endpoint.
if !addressEndpoint.IsAssigned(allowExpired) {
- return
+ return true
}
addr := addressEndpoint.AddressWithPrefix().Address
@@ -1235,8 +1506,13 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
cs = append(cs, addrCandidate{
addressEndpoint: addressEndpoint,
+ addr: addr,
scope: scope,
+ label: getLabel(addr),
+ matchingPrefix: remoteAddr.MatchingPrefix(addr),
})
+
+ return true
})
remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
@@ -1245,18 +1521,20 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
}
+ remoteLabel := getLabel(remoteAddr)
+
// Sort the addresses as per RFC 6724 section 5 rules 1-3.
//
- // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ // TODO(b/146021396): Implement rules 4, 5 of RFC 6724 section 5.
sort.Slice(cs, func(i, j int) bool {
sa := cs[i]
sb := cs[j]
// Prefer same address as per RFC 6724 section 5 rule 1.
- if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ if sa.addr == remoteAddr {
return true
}
- if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ if sb.addr == remoteAddr {
return false
}
@@ -1273,11 +1551,29 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address
return sbDep
}
+ // Prefer matching label as per RFC 6724 section 5 rule 6.
+ if sa, sb := sa.label == remoteLabel, sb.label == remoteLabel; sa != sb {
+ if sa {
+ return true
+ }
+ if sb {
+ return false
+ }
+ }
+
// Prefer temporary addresses as per RFC 6724 section 5 rule 7.
if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp {
return saTemp
}
+ // Use longest matching prefix as per RFC 6724 section 5 rule 8.
+ if sa.matchingPrefix > sb.matchingPrefix {
+ return true
+ }
+ if sb.matchingPrefix > sa.matchingPrefix {
+ return false
+ }
+
// sa and sb are equal, return the endpoint that is closest to the front of
// the primary endpoint list.
return i < j
@@ -1309,35 +1605,52 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.joinGroupLocked(addr)
+}
+
+// joinGroupLocked is like JoinGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
if !header.IsV6MulticastAddress(addr) {
- return false, tcpip.ErrBadAddress
+ return tcpip.ErrBadAddress
}
- e.mu.Lock()
- defer e.mu.Unlock()
- return e.mu.addressableEndpointState.JoinGroup(addr)
+ e.mu.mld.joinGroup(addr)
+ return nil
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- return e.mu.addressableEndpointState.LeaveGroup(addr)
+ return e.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked is like LeaveGroup but with locking requirements.
+//
+// Precondition: e.mu must be locked.
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+ return e.mu.mld.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mu.addressableEndpointState.IsInGroup(addr)
+ return e.mu.mld.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
+var _ fragmentation.TimeoutHandler = (*protocol)(nil)
type protocol struct {
- stack *stack.Stack
+ stack *stack.Stack
+ options Options
mu struct {
sync.RWMutex
@@ -1361,26 +1674,6 @@ type protocol struct {
forwarding uint32
fragmentation *fragmentation.Fragmentation
-
- // ndpDisp is the NDP event dispatcher that is used to send the netstack
- // integrator NDP related events.
- ndpDisp NDPDispatcher
-
- // ndpConfigs is the default NDP configurations used by an IPv6 endpoint.
- ndpConfigs NDPConfigurations
-
- // opaqueIIDOpts hold the options for generating opaque interface identifiers
- // (IIDs) as outlined by RFC 7217.
- opaqueIIDOpts OpaqueInterfaceIdentifierOptions
-
- // tempIIDSeed is used to seed the initial temporary interface identifier
- // history value used to generate IIDs for temporary SLAAC addresses.
- tempIIDSeed []byte
-
- // autoGenIPv6LinkLocal determines whether or not the stack attempts to
- // auto-generate an IPv6 link-local address for newly enabled non-loopback
- // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
- autoGenIPv6LinkLocal bool
}
// Number returns the ipv6 protocol number.
@@ -1413,16 +1706,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
dispatcher: dispatcher,
protocol: p,
}
+ e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
- e.mu.ndp = ndpState{
- ep: e,
- configs: p.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
- }
- e.mu.ndp.initializeTempAddrState()
+ e.mu.ndp.init(e)
+ e.mu.mld.init(e)
+ e.mu.Unlock()
p.mu.Lock()
defer p.mu.Unlock()
@@ -1545,17 +1833,17 @@ type Options struct {
// NDPConfigs is the default NDP configurations used by interfaces.
NDPConfigs NDPConfigurations
- // AutoGenIPv6LinkLocal determines whether or not the stack attempts to
- // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // AutoGenLinkLocal determines whether or not the stack attempts to
+ // auto-generate a link-local address for newly enabled non-loopback
// NICs.
//
// Note, setting this to true does not mean that a link-local address is
// assigned right away, or at all. If Duplicate Address Detection is enabled,
// an address is only assigned if it successfully resolves. If it fails, no
- // further attempts are made to auto-generate an IPv6 link-local adddress.
+ // further attempts are made to auto-generate a link-local adddress.
//
// The generated link-local address follows RFC 4291 Appendix A guidelines.
- AutoGenIPv6LinkLocal bool
+ AutoGenLinkLocal bool
// NDPDisp is the NDP event dispatcher that an integrator can provide to
// receive NDP related events.
@@ -1579,6 +1867,9 @@ type Options struct {
// seed that is too small would reduce randomness and increase predictability,
// defeating the purpose of temporary SLAAC addresses.
TempIIDSeed []byte
+
+ // MLD holds options for MLD.
+ MLD MLDOptions
}
// NewProtocolWithOptions returns an IPv6 network protocol.
@@ -1590,17 +1881,13 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
- stack: s,
- fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
- ids: ids,
- hashIV: hashIV,
-
- ndpDisp: opts.NDPDisp,
- ndpConfigs: opts.NDPConfigs,
- opaqueIIDOpts: opts.OpaqueIIDOpts,
- tempIIDSeed: opts.TempIIDSeed,
- autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ stack: s,
+ options: opts,
+
+ ids: ids,
+ hashIV: hashIV,
}
+ p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[*endpoint]struct{})
p.SetDefaultTTL(DefaultTTL)
return p
@@ -1644,24 +1931,25 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea
fragPkt.NetworkProtocolNumber = ProtocolNumber
originalIPHeadersLength := len(originalIPHeaders)
- fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
+
+ s := header.IPv6ExtHdrSerializer{&header.IPv6SerializableFragmentExtHdr{
+ FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
+ M: more,
+ Identification: id,
+ }}
+
+ fragmentIPHeadersLength := originalIPHeadersLength + s.Length()
fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
- fragPkt.NetworkProtocolNumber = ProtocolNumber
// Copy the IPv6 header and any extension headers already populated.
if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength))
}
- fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader)
- fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
- fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:])
- fragmentHeader.Encode(&header.IPv6FragmentFields{
- M: more,
- FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
- Identification: id,
- NextHeader: uint8(transportProto),
- })
+ nextHeader, _ := s.Serialize(transportProto, fragmentIPHeaders[originalIPHeadersLength:])
+
+ fragmentIPHeaders.SetNextHeader(nextHeader)
+ fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
return fragPkt, more
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 1bfcdde25..360025b20 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,9 +15,12 @@
package ipv6
import (
+ "bytes"
"encoding/hex"
"fmt"
+ "io/ioutil"
"math"
+ "net"
"testing"
"github.com/google/go-cmp/cmp"
@@ -50,6 +53,7 @@ const (
fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
+ unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier)
extraHeaderReserve = 50
)
@@ -67,18 +71,18 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
}))
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
if got := stats.NeighborAdvert.Value(); got != want {
t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
@@ -125,11 +129,11 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -572,6 +576,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
expectICMP: false,
},
{
+ name: "unknown next header (first)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, unknownHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6NextHeaderOffset,
+ },
+ {
+ name: "unknown next header (not first)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ unknownHdrID, 0,
+ 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
+ },
+ {
name: "destination with unknown option skippable action",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
@@ -754,11 +785,6 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
pointer: header.IPv6FixedHeaderSize,
},
{
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
- shouldAccept: false,
- },
- {
name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
@@ -820,13 +846,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
},
}
+ const mtu = header.IPv6MinimumMTU
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
+ e := channel.New(1, mtu, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -872,7 +899,13 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
Length: uint16(udpLength),
})
copy(u.Payload(), udpPayload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+
+ dstAddr := tcpip.Address(addr2)
+ if test.multicast {
+ dstAddr = header.IPv6AllNodesMulticastAddress
+ }
+
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength))
sum = header.Checksum(udpPayload, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
@@ -883,16 +916,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Serialize IPv6 fixed header.
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- dstAddr := tcpip.Address(addr2)
- if test.multicast {
- dstAddr = header.IPv6AllNodesMulticastAddress
- }
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
- NextHeader: ipv6NextHdr,
- HopLimit: 255,
- SrcAddr: addr1,
- DstAddr: dstAddr,
+ // We're lying about transport protocol here to be able to generate
+ // raw extension headers from the test definitions.
+ TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr),
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: dstAddr,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -951,17 +982,24 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
if got := stats.Value(); got != 1 {
t.Errorf("got UDP Rx Packets = %d, want = 1", got)
}
- gotPayload, _, err := ep.Read(nil)
+ var buf bytes.Buffer
+ result, err := ep.Read(&buf, mtu, tcpip.ReadOptions{})
if err != nil {
- t.Fatalf("Read(nil): %s", err)
+ t.Fatalf("Read: %s", err)
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: len(udpPayload),
+ Total: len(udpPayload),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
}
- if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" {
+ if diff := cmp.Diff(udpPayload, buf.Bytes()); diff != "" {
t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
// Should not have any more UDP packets.
- if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if res, err := ep.Read(ioutil.Discard, mtu, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
}
})
}
@@ -981,9 +1019,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
udpPayload2Length = 128
// Used to test cases where the fragment blocks are not a multiple of
// the fragment block size of 8 (RFC 8200 section 4.5).
- udpPayload3Length = 127
- udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
- fragmentExtHdrLen = 8
+ udpPayload3Length = 127
+ udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
+ udpMaximumSizeMinus15 = header.UDPMaximumSize - 15
+ fragmentExtHdrLen = 8
// Note, not all routing extension headers will be 8 bytes but this test
// uses 8 byte routing extension headers for most sub tests.
routingExtHdrLen = 8
@@ -1327,14 +1366,14 @@ func TestReceiveIPv6Fragments(t *testing.T) {
dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+65520,
+ fragmentExtHdrLen+udpMaximumSizeMinus15,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload4Addr1ToAddr2[:65520],
+ ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
},
),
},
@@ -1343,14 +1382,17 @@ func TestReceiveIPv6Fragments(t *testing.T) {
dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520,
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
[]buffer.View{
// Fragment extension header.
//
- // Fragment offset = 8190, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}),
+ // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ udpMaximumSizeMinus15 >> 8,
+ udpMaximumSizeMinus15 & 0xff,
+ 0, 0, 0, 1}),
- ipv6Payload4Addr1ToAddr2[65520:],
+ ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
},
),
},
@@ -1358,6 +1400,47 @@ func TestReceiveIPv6Fragments(t *testing.T) {
expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
},
{
+ name: "Two fragments with MF flag reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+udpMaximumSizeMinus15,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ udpMaximumSizeMinus15 >> 8,
+ (udpMaximumSizeMinus15 & 0xff) + 1,
+ 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
name: "Two fragments with per-fragment routing header with zero segments left",
fragments: []fragmentData{
{
@@ -1876,10 +1959,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(f.data.Size()),
- NextHeader: f.nextHdr,
- HopLimit: 255,
- SrcAddr: f.srcAddr,
- DstAddr: f.dstAddr,
+ // We're lying about transport protocol here so that we can generate
+ // raw extension headers for the tests.
+ TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr),
+ HopLimit: 255,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
})
vv := hdr.View().ToVectorisedView()
@@ -1894,18 +1979,20 @@ func TestReceiveIPv6Fragments(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
}
+ const rcvSize = 65536 // Account for reassembled packets.
for i, p := range test.expectedPayloads {
- gotPayload, _, err := ep.Read(nil)
+ var buf bytes.Buffer
+ _, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{})
if err != nil {
- t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ t.Fatalf("(i=%d) Read: %s", i, err)
}
- if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" {
+ if diff := cmp.Diff(p, buf.Bytes()); diff != "" {
t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
}
}
- if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
}
})
}
@@ -1924,7 +2011,7 @@ func TestInvalidIPv6Fragments(t *testing.T) {
type fragmentData struct {
ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6FragmentFields
+ ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
payload []byte
}
@@ -1943,14 +2030,13 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 9,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 9,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0 >> 3,
M: true,
Identification: ident,
@@ -1970,14 +2056,13 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3,
M: false,
Identification: ident,
@@ -2018,10 +2103,9 @@ func TestInvalidIPv6Fragments(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- ip.Encode(&f.ipv6Fields)
-
- fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
- fragHDR.Encode(&f.ipv6FragmentFields)
+ encodeArgs := f.ipv6Fields
+ encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
+ ip.Encode(&encodeArgs)
vv := hdr.View().ToVectorisedView()
vv.AppendView(f.payload)
@@ -2083,7 +2167,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
type fragmentData struct {
ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6FragmentFields
+ ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
payload []byte
}
@@ -2097,14 +2181,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2119,14 +2202,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2135,14 +2217,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2157,14 +2238,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2179,14 +2259,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2195,14 +2274,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2217,14 +2295,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
fragments: []fragmentData{
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 8,
M: false,
Identification: ident,
@@ -2233,14 +2310,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
},
{
ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- NextHeader: header.IPv6FragmentHeader,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
},
- ipv6FragmentFields: header.IPv6FragmentFields{
- NextHeader: uint8(header.UDPProtocolNumber),
+ ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
FragmentOffset: 0,
M: true,
Identification: ident,
@@ -2279,10 +2355,11 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- ip.Encode(&f.ipv6Fields)
+ encodeArgs := f.ipv6Fields
+ encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
+ ip.Encode(&encodeArgs)
fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
- fragHDR.Encode(&f.ipv6FragmentFields)
vv := hdr.View().ToVectorisedView()
vv.AppendView(f.payload)
@@ -2438,7 +2515,7 @@ func TestWriteStats(t *testing.T) {
test.setup(t, rt.Stack())
- nWritten, _ := writer.writePackets(&rt, pkts)
+ nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
@@ -2455,7 +2532,7 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
@@ -2821,3 +2898,160 @@ func TestFragmentationErrors(t *testing.T) {
})
}
}
+
+func TestForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+ randomSequence = 123
+ randomIdent = 42
+ )
+
+ ipv6Addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::1").To16()),
+ PrefixLen: 64,
+ }
+ ipv6Addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("11::1").To16()),
+ PrefixLen: 64,
+ }
+ remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16())
+ remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16())
+
+ tests := []struct {
+ name string
+ TTL uint8
+ expectErrorICMP bool
+ }{
+ {
+ name: "TTL of zero",
+ TTL: 0,
+ expectErrorICMP: true,
+ },
+ {
+ name: "TTL of one",
+ TTL: 1,
+ expectErrorICMP: true,
+ },
+ {
+ name: "TTL of two",
+ TTL: 2,
+ expectErrorICMP: false,
+ },
+ {
+ name: "TTL of three",
+ TTL: 3,
+ expectErrorICMP: false,
+ },
+ {
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ expectErrorICMP: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1}
+ if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2}
+ if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: ipv6Addr1.Subnet(),
+ NIC: nicID1,
+ },
+ {
+ Destination: ipv6Addr2.Subnet(),
+ NIC: nicID2,
+ },
+ })
+
+ if err := s.SetForwarding(ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err)
+ }
+
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize)
+ icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv6EchoRequest)
+ icmp.SetCode(header.ICMPv6UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: test.TTL,
+ SrcAddr: remoteIPv6Addr1,
+ DstAddr: remoteIPv6Addr2,
+ })
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e1.InjectInbound(ProtocolNumber, requestPkt)
+
+ if test.expectErrorICMP {
+ reply, ok := e1.Read()
+ if !ok {
+ t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC")
+ }
+
+ checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(ipv6Addr1.Address),
+ checker.DstAddr(remoteIPv6Addr1),
+ checker.TTL(DefaultTTL),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6TimeExceeded),
+ checker.ICMPv6Code(header.ICMPv6HopLimitExceeded),
+ checker.ICMPv6Payload([]byte(hdr.View())),
+ ),
+ )
+
+ if n := e2.Drain(); n != 0 {
+ t.Fatalf("got e2.Drain() = %d, want = 0", n)
+ }
+ } else {
+ reply, ok := e2.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo Request packet through outgoing NIC")
+ }
+
+ checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(remoteIPv6Addr1),
+ checker.DstAddr(remoteIPv6Addr2),
+ checker.TTL(test.TTL-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoRequest),
+ checker.ICMPv6Code(header.ICMPv6UnusedCode),
+ checker.ICMPv6Payload(nil),
+ ),
+ )
+
+ if n := e1.Drain(); n != 0 {
+ t.Fatalf("got e1.Drain() = %d, want = 0", n)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
new file mode 100644
index 000000000..e8d1e7a79
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -0,0 +1,262 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // UnsolicitedReportIntervalMax is the maximum delay between sending
+ // unsolicited MLD reports.
+ //
+ // Obtained from RFC 2710 Section 7.10.
+ UnsolicitedReportIntervalMax = 10 * time.Second
+)
+
+// MLDOptions holds options for MLD.
+type MLDOptions struct {
+ // Enabled indicates whether MLD will be performed.
+ //
+ // When enabled, MLD may transmit MLD report and done messages when
+ // joining and leaving multicast groups respectively, and handle incoming
+ // MLD packets.
+ //
+ // This field is ignored and is always assumed to be false for interfaces
+ // without neighbouring nodes (e.g. loopback).
+ Enabled bool
+}
+
+var _ ip.MulticastGroupProtocol = (*mldState)(nil)
+
+// mldState is the per-interface MLD state.
+//
+// mldState.init MUST be called to initialize the MLD state.
+type mldState struct {
+ // The IPv6 endpoint this mldState is for.
+ ep *endpoint
+
+ genericMulticastProtocol ip.GenericMulticastProtocolState
+}
+
+// Enabled implements ip.MulticastGroupProtocol.
+func (mld *mldState) Enabled() bool {
+ // No need to perform MLD on loopback interfaces since they don't have
+ // neighbouring nodes.
+ return mld.ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback() && mld.ep.Enabled()
+}
+
+// SendReport implements ip.MulticastGroupProtocol.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+ return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport)
+}
+
+// SendLeave implements ip.MulticastGroupProtocol.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+ return err
+}
+
+// init sets up an mldState struct, and is required to be called before using
+// a new mldState.
+//
+// Must only be called once for the lifetime of mld.
+func (mld *mldState) init(ep *endpoint) {
+ mld.ep = ep
+ mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
+ Rand: ep.protocol.stack.Rand(),
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: mld,
+ MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
+ AllNodesAddress: header.IPv6AllNodesMulticastAddress,
+ })
+}
+
+// handleMulticastListenerQuery handles a query message.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) {
+ mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay())
+}
+
+// handleMulticastListenerReport handles a report message.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
+ mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress())
+}
+
+// joinGroup handles joining a new group and sending and scheduling the required
+// messages.
+//
+// If the group is already joined, returns tcpip.ErrDuplicateAddress.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
+ mld.genericMulticastProtocol.JoinGroupLocked(groupAddress)
+}
+
+// isInGroup returns true if the specified group has been joined locally.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool {
+ return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
+}
+
+// leaveGroup handles removing the group from the membership map, cancels any
+// delay timers associated with that group, and sends the Done message, if
+// required.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+ // LeaveGroup returns false only if the group was not joined.
+ if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
+ return nil
+ }
+
+ return tcpip.ErrBadLocalAddress
+}
+
+// softLeaveAll leaves all groups from the perspective of MLD, but remains
+// joined locally.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) softLeaveAll() {
+ mld.genericMulticastProtocol.MakeAllNonMemberLocked()
+}
+
+// initializeAll attemps to initialize the MLD state for each group that has
+// been joined locally.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) initializeAll() {
+ mld.genericMulticastProtocol.InitializeGroupsLocked()
+}
+
+// sendQueuedReports attempts to send any reports that are queued for sending.
+//
+// Precondition: mld.ep.mu must be locked.
+func (mld *mldState) sendQueuedReports() {
+ mld.genericMulticastProtocol.SendQueuedReportsLocked()
+}
+
+// writePacket assembles and sends an MLD packet.
+//
+// Precondition: mld.ep.mu must be read locked.
+func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) {
+ sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ var mldStat *tcpip.StatCounter
+ switch mldType {
+ case header.ICMPv6MulticastListenerReport:
+ mldStat = sentStats.MulticastListenerReport
+ case header.ICMPv6MulticastListenerDone:
+ mldStat = sentStats.MulticastListenerDone
+ default:
+ panic(fmt.Sprintf("unrecognized mld type = %d", mldType))
+ }
+
+ icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize))
+ icmp.SetType(mldType)
+ header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress)
+ // As per RFC 2710 section 3,
+ //
+ // All MLD messages described in this document are sent with a link-local
+ // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert
+ // option in a Hop-by-Hop Options header.
+ //
+ // However, this would cause problems with Duplicate Address Detection with
+ // the first address as MLD snooping switches may not send multicast traffic
+ // that DAD depends on to the node performing DAD without the MLD report, as
+ // documented in RFC 4816:
+ //
+ // Note that when a node joins a multicast address, it typically sends a
+ // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810]
+ // for the multicast address. In the case of Duplicate Address
+ // Detection, the MLD report message is required in order to inform MLD-
+ // snooping switches, rather than routers, to forward multicast packets.
+ // In the above description, the delay for joining the multicast address
+ // thus means delaying transmission of the corresponding MLD report
+ // message. Since the MLD specifications do not request a random delay
+ // to avoid race conditions, just delaying Neighbor Solicitation would
+ // cause congestion by the MLD report messages. The congestion would
+ // then prevent the MLD-snooping switches from working correctly and, as
+ // a result, prevent Duplicate Address Detection from working. The
+ // requirement to include the delay for the MLD report in this case
+ // avoids this scenario. [RFC3590] also talks about some interaction
+ // issues between Duplicate Address Detection and MLD, and specifies
+ // which source address should be used for the MLD report in this case.
+ //
+ // As per RFC 3590 section 4, we should still send out MLD reports with an
+ // unspecified source address if we do not have an assigned link-local
+ // address to use as the source address to ensure DAD works as expected on
+ // networks with MLD snooping switches:
+ //
+ // MLD Report and Done messages are sent with a link-local address as
+ // the IPv6 source address, if a valid address is available on the
+ // interface. If a valid link-local address is not available (e.g., one
+ // has not been configured), the message is sent with the unspecified
+ // address (::) as the IPv6 source address.
+ //
+ // Once a valid link-local address is available, a node SHOULD generate
+ // new MLD Report messages for all multicast addresses joined on the
+ // interface.
+ //
+ // Routers receiving an MLD Report or Done message with the unspecified
+ // address as the IPv6 source address MUST silently discard the packet
+ // without taking any action on the packets contents.
+ //
+ // Snooping switches MUST manage multicast forwarding state based on MLD
+ // Report and Done messages sent with the unspecified address as the
+ // IPv6 source address.
+ localAddress := mld.ep.getLinkLocalAddressRLocked()
+ if len(localAddress) == 0 {
+ localAddress = header.IPv6Any
+ }
+
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{}))
+
+ extensionHeaders := header.IPv6ExtHdrSerializer{
+ header.IPv6SerializableHopByHopExtHdr{
+ &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
+ },
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(),
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+
+ mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.MLDHopLimit,
+ }, extensionHeaders)
+ if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ sentStats.Dropped.Increment()
+ return false, err
+ }
+ mldStat.Increment()
+ return localAddress != header.IPv6Any, nil
+}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
new file mode 100644
index 000000000..f6ffa7133
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -0,0 +1,297 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+)
+
+var (
+ linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr)
+ globalAddrSNMC = header.SolicitedNodeAddr(globalAddr)
+)
+
+func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) {
+ t.Helper()
+
+ checker.IPv6WithExtHdr(t, p,
+ checker.IPv6ExtHdr(
+ checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
+ ),
+ checker.SrcAddr(localAddress),
+ checker.DstAddr(remoteAddress),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(mldType, header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(0),
+ checker.MLDMulticastAddress(groupAddress),
+ ),
+ )
+}
+
+func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ // The stack will join an address's solicited node multicast address when
+ // an address is added. An MLD report message should be sent for the
+ // solicited-node group.
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+
+ // The stack will leave an address's solicited node multicast address when
+ // an address is removed. An MLD done message should be sent for the
+ // solicited-node group.
+ if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a done message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
+ }
+}
+
+func TestSendQueuedMLDReports(t *testing.T) {
+ const (
+ nicID = 1
+ maxReports = 2
+ )
+
+ tests := []struct {
+ name string
+ dadTransmits uint8
+ retransmitTimer time.Duration
+ }{
+ {
+ name: "DAD Disabled",
+ dadTransmits: 0,
+ retransmitTimer: 0,
+ },
+ {
+ name: "DAD Enabled",
+ dadTransmits: 1,
+ retransmitTimer: time.Second,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: test.dadTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ },
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ Clock: clock,
+ })
+
+ // Allow space for an extra packet so we can observe packets that were
+ // unexpectedly sent.
+ e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ resolveDAD := func(addr, snmc tcpip.Address) {
+ clock.Advance(dadResolutionTime)
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected DAD packet")
+ } else {
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(addr),
+ checker.NDPNSOptions(nil),
+ ))
+ }
+ }
+
+ var reportCounter uint64
+ reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ var doneCounter uint64
+ doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ if got := doneStat.Value(); got != doneCounter {
+ t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
+ }
+
+ // Joining a group without an assigned address should send an MLD report
+ // with the unspecified address.
+ if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err)
+ }
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", globalMulticastAddr)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Adding a global address should not send reports for the already joined
+ // group since we should only send queued reports when a link-local
+ // addres sis assigned.
+ //
+ // Note, we will still expect to send a report for the global address's
+ // solicited node address from the unspecified address as per RFC 3590
+ // section 4.
+ if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
+ }
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", globalAddrSNMC)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC)
+ }
+ if dadResolutionTime != 0 {
+ // Reports should not be sent when the address resolves.
+ resolveDAD(globalAddr, globalAddrSNMC)
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ }
+ // Leave the group since we don't care about the global address's
+ // solicited node multicast group membership.
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err)
+ }
+ if got := doneStat.Value(); got != doneCounter {
+ t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Adding a link-local address should send a report for its solicited node
+ // address and globalMulticastAddr.
+ if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
+ t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
+ }
+ if dadResolutionTime != 0 {
+ reportCounter++
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Errorf("expected MLD report for %s", linkLocalAddrSNMC)
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+ resolveDAD(linkLocalAddr, linkLocalAddrSNMC)
+ }
+
+ // We expect two batches of reports to be sent (1 batch when the
+ // link-local address is assigned, and another after the maximum
+ // unsolicited report interval.
+ for i := 0; i < 2; i++ {
+ // We expect reports to be sent (one for globalMulticastAddr and another
+ // for linkLocalAddrSNMC).
+ reportCounter += maxReports
+ if got := reportStat.Value(); got != reportCounter {
+ t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+
+ addrs := map[tcpip.Address]bool{
+ globalMulticastAddr: false,
+ linkLocalAddrSNMC: false,
+ }
+ for range addrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs)
+ }
+
+ addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress()
+ if seen, ok := addrs[addr]; !ok {
+ t.Fatalf("got unexpected packet destined to %s", addr)
+ } else if seen {
+ t.Fatalf("got another packet destined to %s", addr)
+ }
+
+ addrs[addr] = true
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr)
+
+ clock.Advance(ipv6.UnsolicitedReportIntervalMax)
+ }
+ }
+
+ // Should not send any more reports.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 40da011f8..d515eb622 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -20,6 +20,7 @@ import (
"math/rand"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() {
// ndpState is the per-interface NDP state.
type ndpState struct {
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
+
// The IPv6 endpoint this ndpState is for.
ep *endpoint
@@ -471,17 +475,8 @@ type ndpState struct {
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
- rtrSolicit struct {
- // The timer used to send the next router solicitation message.
- timer tcpip.Timer
-
- // Used to let the Router Solicitation timer know that it has been stopped.
- //
- // Must only be read from or written to while protected by the lock of
- // the IPv6 endpoint this ndpState is associated with. MUST be set when the
- // timer is set.
- done *bool
- }
+ // The job used to send the next router solicitation message.
+ rtrSolicitJob *tcpip.Job
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
@@ -507,7 +502,7 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer tcpip.Timer
+ job *tcpip.Job
// Used to let the DAD timer know that it has been stopped.
//
@@ -648,96 +643,73 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// Consider DAD to have resolved even if no DAD messages were actually
// transmitted.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
+ ndp.ep.onAddressAssignedLocked(addr)
return nil
}
- var done bool
- var timer tcpip.Timer
- // We initially start a timer to fire immediately because some of the DAD work
- // cannot be done while holding the IPv6 endpoint's lock. This is effectively
- // the same as starting a goroutine but we use a timer that fires immediately
- // so we can reset it for the next DAD iteration.
- timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
- ndp.ep.mu.Lock()
- defer ndp.ep.mu.Unlock()
-
- if done {
- // If we reach this point, it means that the DAD timer fired after
- // another goroutine already obtained the IPv6 endpoint lock and stopped
- // DAD before this function obtained the NIC lock. Simply return here and
- // do nothing further.
- return
- }
+ state := dadState{
+ job: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ state, ok := ndp.dad[addr]
+ if !ok {
+ panic(fmt.Sprintf("ndpdad: DAD timer fired but missing state for %s on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
- if addressEndpoint.GetKind() != stack.PermanentTentative {
- // The endpoint should still be marked as tentative since we are still
- // performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
- }
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ // The endpoint should still be marked as tentative since we are still
+ // performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
- dadDone := remaining == 0
-
- var err *tcpip.Error
- if !dadDone {
- // Use the unspecified address as the source address when performing DAD.
- addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
-
- // Do not hold the lock when sending packets which may be a long running
- // task or may block link address resolution. We know this is safe
- // because immediately after obtaining the lock again, we check if DAD
- // has been stopped before doing any work with the IPv6 endpoint. Note,
- // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if
- // the address was removed.
- ndp.ep.mu.Unlock()
- err = ndp.sendDADPacket(addr, addressEndpoint)
- ndp.ep.mu.Lock()
- addressEndpoint.DecRef()
- }
+ dadDone := remaining == 0
- if done {
- // If we reach this point, it means that DAD was stopped after we released
- // the IPv6 endpoint's read lock and before we obtained the write lock.
- return
- }
+ var err *tcpip.Error
+ if !dadDone {
+ err = ndp.sendDADPacket(addr, addressEndpoint)
+ }
- if dadDone {
- // DAD has resolved.
- addressEndpoint.SetKind(stack.Permanent)
- } else if err == nil {
- // DAD is not done and we had no errors when sending the last NDP NS,
- // schedule the next DAD timer.
- remaining--
- timer.Reset(ndp.configs.RetransmitTimer)
- return
- }
+ if dadDone {
+ // DAD has resolved.
+ addressEndpoint.SetKind(stack.Permanent)
+ } else if err == nil {
+ // DAD is not done and we had no errors when sending the last NDP NS,
+ // schedule the next DAD timer.
+ remaining--
+ state.job.Schedule(ndp.configs.RetransmitTimer)
+ return
+ }
- // At this point we know that either DAD is done or we hit an error sending
- // the last NDP NS. Either way, clean up addr's DAD state and let the
- // integrator know DAD has completed.
- delete(ndp.dad, addr)
+ // At this point we know that either DAD is done or we hit an error
+ // sending the last NDP NS. Either way, clean up addr's DAD state and let
+ // the integrator know DAD has completed.
+ delete(ndp.dad, addr)
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
- }
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
+ }
- // If DAD resolved for a stable SLAAC address, attempt generation of a
- // temporary SLAAC address.
- if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
- // Reset the generation attempts counter as we are starting the generation
- // of a new address for the SLAAC prefix.
- ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
- }
- })
+ if dadDone {
+ if addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
+ // Reset the generation attempts counter as we are starting the
+ // generation of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
+ }
- ndp.dad[addr] = dadState{
- timer: timer,
- done: &done,
+ ndp.ep.onAddressAssignedLocked(addr)
+ }
+ }),
}
+ // We initially start a timer to fire immediately because some of the DAD work
+ // cannot be done while holding the IPv6 endpoint's lock. This is effectively
+ // the same as starting a goroutine but we use a timer that fires immediately
+ // so we can reset it for the next DAD iteration.
+ state.job.Schedule(0)
+ ndp.dad[addr] = state
+
return nil
}
@@ -745,55 +717,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// addr.
//
// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
-//
-// The IPv6 endpoint that ndp belongs to MUST NOT be locked.
func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
- }
- defer r.Release()
-
- // Route should resolve immediately since snmc is a multicast address so a
- // remote link address can be calculated without a resolution process.
- if c, err := r.Resolve(nil); err != nil {
- // Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the IPv6 endpoint is not
- // locked, the NIC could have been disabled or removed by another goroutine.
- if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
- return err
- }
-
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
- } else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
- }
-
- icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
- icmpData.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmpData.NDPPayload())
+ icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ns.SetTargetAddress(addr)
- icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, snmc, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(icmpData).ToVectorisedView(),
+ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
+ Data: buffer.View(icmp).ToVectorisedView(),
})
- sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil,
- stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, pkt,
- ); err != nil {
+ sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, nil /* extensionHeaders */)
+
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
return err
}
sent.NeighborSolicit.Increment()
-
return nil
}
@@ -812,18 +760,11 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
return
}
- if dad.timer != nil {
- dad.timer.Stop()
- dad.timer = nil
-
- *dad.done = true
- dad.done = nil
- }
-
+ dad.job.Cancel()
delete(ndp.dad, addr)
// Let the integrator know DAD did not resolve.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
}
}
@@ -846,7 +787,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
// only inform the dispatcher on configuration changes. We do nothing else
// with the information.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
var configuration DHCPv6ConfigurationFromNDPRA
switch {
case ra.ManagedAddrConfFlag():
@@ -903,20 +844,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
switch opt := opt.(type) {
case header.NDPRecursiveDNSServer:
- if ndp.ep.protocol.ndpDisp == nil {
+ if ndp.ep.protocol.options.NDPDisp == nil {
continue
}
addrs, _ := opt.Addresses()
- ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
+ ndp.ep.protocol.options.NDPDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
case header.NDPDNSSearchList:
- if ndp.ep.protocol.ndpDisp == nil {
+ if ndp.ep.protocol.options.NDPDisp == nil {
continue
}
domainNames, _ := opt.DomainNames()
- ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
+ ndp.ep.protocol.options.NDPDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
case header.NDPPrefixInformation:
prefix := opt.Subnet()
@@ -964,7 +905,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
}
}
@@ -976,7 +917,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
- ndpDisp := ndp.ep.protocol.ndpDisp
+ ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
@@ -1006,7 +947,7 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
- ndpDisp := ndp.ep.protocol.ndpDisp
+ ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
@@ -1047,7 +988,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix)
}
}
@@ -1225,7 +1166,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint {
// Inform the integrator that we have a new SLAAC address.
- ndpDisp := ndp.ep.protocol.ndpDisp
+ ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return nil
}
@@ -1272,7 +1213,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
}
dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures
- if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ if oIID := ndp.ep.protocol.options.OpaqueIIDOpts; oIID.NICNameFromID != nil {
addrBytes = header.AppendOpaqueInterfaceIdentifier(
addrBytes[:header.IIDOffsetInIPv6Address],
prefix,
@@ -1676,7 +1617,7 @@ func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint
}
addressEndpoint.SetDeprecated(true)
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
}
}
@@ -1701,7 +1642,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
@@ -1761,7 +1702,7 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
- if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
@@ -1859,7 +1800,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
- if ndp.rtrSolicit.timer != nil {
+ if ndp.rtrSolicitJob != nil {
// We are already soliciting routers.
return
}
@@ -1876,56 +1817,14 @@ func (ndp *ndpState) startSolicitingRouters() {
delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
}
- var done bool
- ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
- ndp.ep.mu.Lock()
- if done {
- // If we reach this point, it means that the RS timer fired after another
- // goroutine already obtained the IPv6 endpoint lock and stopped
- // solicitations. Simply return here and do nothing further.
- ndp.ep.mu.Unlock()
- return
- }
-
+ ndp.rtrSolicitJob = ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
- if addressEndpoint == nil {
- // Incase this ends up creating a new temporary address, we need to hold
- // onto the endpoint until a route is obtained. If we decrement the
- // reference count before obtaing a route, the address's resources would
- // be released and attempting to obtain a route after would fail. Once a
- // route is obtainted, it is safe to decrement the reference count since
- // obtaining a route increments the address's reference count.
- addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
- }
- ndp.ep.mu.Unlock()
-
- localAddr := addressEndpoint.AddressWithPrefix().Address
- r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
- addressEndpoint.DecRef()
- if err != nil {
- return
- }
- defer r.Release()
-
- // Route should resolve immediately since
- // header.IPv6AllRoutersMulticastAddress is a multicast address so a
- // remote link address can be calculated without a resolution process.
- if c, err := r.Resolve(nil); err != nil {
- // Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the IPv6 endpoint is
- // not locked, the IPv6 endpoint could have been disabled or removed by
- // another goroutine.
- if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
- return
- }
-
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
- } else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
+ localAddr := header.IPv6Any
+ if addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil {
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ addressEndpoint.DecRef()
}
// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
@@ -1936,30 +1835,31 @@ func (ndp *ndpState) startSolicitingRouters() {
// TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
// LinkEndpoint.LinkAddress) before reaching this point.
var optsSerializer header.NDPOptionsSerializer
- if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) {
+ linkAddress := ndp.ep.nic.LinkAddress()
+ if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) {
optsSerializer = header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress),
+ header.NDPSourceLinkLayerAddressOption(linkAddress),
}
}
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
icmpData := header.ICMPv6(buffer.NewView(payloadSize))
icmpData.SetType(header.ICMPv6RouterSolicit)
- rs := header.NDPRouterSolicit(icmpData.NDPPayload())
+ rs := header.NDPRouterSolicit(icmpData.MessageBody())
rs.Options().Serialize(optsSerializer)
- icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
Data: buffer.View(icmpData).ToVectorisedView(),
})
- sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil,
- stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, pkt,
- ); err != nil {
+ sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
+ ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, nil /* extensionHeaders */)
+
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
@@ -1969,21 +1869,12 @@ func (ndp *ndpState) startSolicitingRouters() {
remaining--
}
- ndp.ep.mu.Lock()
- if done || remaining == 0 {
- ndp.rtrSolicit.timer = nil
- ndp.rtrSolicit.done = nil
- } else if ndp.rtrSolicit.timer != nil {
- // Note, we need to explicitly check to make sure that
- // the timer field is not nil because if it was nil but
- // we still reached this point, then we know the IPv6 endpoint
- // was requested to stop soliciting routers so we don't
- // need to send the next Router Solicitation message.
- ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
+ if remaining != 0 {
+ ndp.rtrSolicitJob.Schedule(ndp.configs.RtrSolicitationInterval)
}
- ndp.ep.mu.Unlock()
})
+ ndp.rtrSolicitJob.Schedule(delay)
}
// stopSolicitingRouters stops soliciting routers. If routers are not currently
@@ -1991,22 +1882,28 @@ func (ndp *ndpState) startSolicitingRouters() {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
- if ndp.rtrSolicit.timer == nil {
+ if ndp.rtrSolicitJob == nil {
// Nothing to do.
return
}
- *ndp.rtrSolicit.done = true
- ndp.rtrSolicit.timer.Stop()
- ndp.rtrSolicit.timer = nil
- ndp.rtrSolicit.done = nil
+ ndp.rtrSolicitJob.Cancel()
+ ndp.rtrSolicitJob = nil
}
-// initializeTempAddrState initializes state related to temporary SLAAC
-// addresses.
-func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID())
+func (ndp *ndpState) init(ep *endpoint) {
+ if ndp.dad != nil {
+ panic("attempted to initialize NDP state twice")
+ }
+
+ ndp.ep = ep
+ ndp.configs = ep.protocol.options.NDPConfigs
+ ndp.dad = make(map[tcpip.Address]dadState)
+ ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
+ ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
+ ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 981d1371a..b1a5a5510 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -45,10 +45,6 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
- }
-
{
subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
if err != nil {
@@ -73,6 +69,17 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
}
t.Cleanup(ep.Close)
+ addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
+ }
+ addr := llladdr.WithPrefix()
+ if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ } else {
+ addressEP.DecRef()
+ }
+
return s, ep
}
@@ -198,7 +205,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -206,14 +213,14 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -304,7 +311,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(lladdr0)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -312,23 +319,23 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
neighbors, err := s.Neighbors(nicID)
if err != nil {
@@ -574,7 +581,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv6EmptySubnet,
NIC: 1,
},
@@ -584,7 +591,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(nicAddr)
opts := ns.Options()
opts.Serialize(test.nsOpts)
@@ -592,14 +599,14 @@ func TestNeighorSolicitationResponse(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: test.nsSrc,
- DstAddr: test.nsDst,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -665,7 +672,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(test.nsSrc)
@@ -674,11 +681,11 @@ func TestNeighorSolicitationResponse(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.nsSrc,
- DstAddr: nicAddr,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.nsSrc,
+ DstAddr: nicAddr,
})
e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -770,7 +777,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns := header.NDPNeighborAdvert(pkt.MessageBody())
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -778,14 +785,14 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
@@ -883,7 +890,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns := header.NDPNeighborAdvert(pkt.MessageBody())
ns.SetTargetAddress(lladdr1)
opts := ns.Options()
copy(opts, test.optsBuf)
@@ -891,23 +898,23 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
})
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
neighbors, err := s.Neighbors(nicID)
if err != nil {
@@ -961,46 +968,36 @@ func TestNDPValidation(t *testing.T) {
for _, stackTyp := range stacks {
t.Run(stackTyp.name, func(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
// Create a stack with the assigned link-local address lladdr0
// and an endpoint to lladdr1.
s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache)
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
- }
-
- return s, ep, r
+ return s, ep
}
- handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
- nextHdr := uint8(header.ICMPv6ProtocolNumber)
- var extensions buffer.View
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
+ var extHdrs header.IPv6ExtHdrSerializer
if atomicFragment {
- extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
- extensions[0] = nextHdr
- nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{})
}
+ extHdrsLen := extHdrs.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions),
+ ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen,
Data: payload.ToVectorisedView(),
})
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions)))
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload) + len(extensions)),
- NextHeader: nextHdr,
- HopLimit: hopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
+ PayloadLength: uint16(len(payload) + extHdrsLen),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: hopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ ExtensionHeaders: extHdrs,
})
- if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
- t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
- }
- r.PopulatePacketInfo(pkt)
ep.HandlePacket(pkt)
}
@@ -1114,15 +1111,14 @@ func TestNDPValidation(t *testing.T) {
t.Run(name, func(t *testing.T) {
for _, test := range subTests {
t.Run(test.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
+ s, ep := setup(t)
if isRouter {
// Enabling forwarding makes the stack act as a router.
s.SetForwarding(ProtocolNumber, true)
}
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
typStat := typ.statCounter(stats)
@@ -1131,7 +1127,7 @@ func TestNDPValidation(t *testing.T) {
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
// Rx count of the NDP message should initially be 0.
if got := typStat.Value(); got != 0 {
@@ -1152,7 +1148,7 @@ func TestNDPValidation(t *testing.T) {
t.FailNow()
}
- handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep)
// Rx count of the NDP packet should have increased.
if got := typStat.Value(); got != 1 {
@@ -1346,19 +1342,19 @@ func TestRouterAdvertValidation(t *testing.T) {
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(test.code)
- copy(pkt.NDPPayload(), test.ndpPayload)
+ copy(pkt.MessageBody(), test.ndpPayload)
payloadLength := hdr.UsedLength()
pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
- stats := s.Stats().ICMP.V6PacketsReceived
+ stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
rxRA := stats.RouterAdvert
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
new file mode 100644
index 000000000..0f4f0e1e1
--- /dev/null
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -0,0 +1,1261 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+
+ ipv4Addr = tcpip.Address("\x0a\x00\x00\x01")
+ ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+
+ ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03")
+ ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04")
+ ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05")
+ ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04")
+ ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05")
+
+ igmpMembershipQuery = uint8(header.IGMPMembershipQuery)
+ igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
+ igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport)
+ igmpLeaveGroup = uint8(header.IGMPLeaveGroup)
+ mldQuery = uint8(header.ICMPv6MulticastListenerQuery)
+ mldReport = uint8(header.ICMPv6MulticastListenerReport)
+ mldDone = uint8(header.ICMPv6MulticastListenerDone)
+
+ maxUnsolicitedReports = 2
+)
+
+var (
+ // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the
+ // NIC will wait before sending an unsolicited report after joining a
+ // multicast group, in deciseconds.
+ unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 {
+ const decisecond = time.Second / 10
+ if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 {
+ panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax))
+ }
+ return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
+ }()
+
+ ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr)
+)
+
+// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet
+// sent to the provided address with the passed fields set.
+func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv6WithExtHdr(t, payload,
+ checker.IPv6ExtHdr(
+ checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
+ ),
+ checker.SrcAddr(ipv6Addr),
+ checker.DstAddr(remoteAddress),
+ // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
+ checker.TTL(1),
+ checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize,
+ checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond),
+ checker.MLDMulticastAddress(groupAddress),
+ ),
+ )
+}
+
+// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet
+// sent to the provided address with the passed fields set.
+func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) {
+ t.Helper()
+
+ payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ checker.IPv4(t, payload,
+ checker.SrcAddr(ipv4Addr),
+ checker.DstAddr(remoteAddress),
+ // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
+ checker.TTL(1),
+ checker.IPv4RouterAlert(),
+ checker.IGMP(
+ checker.IGMPType(header.IGMPType(igmpType)),
+ checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
+ checker.IGMPGroupAddress(groupAddress),
+ ),
+ )
+}
+
+func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr)
+ s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e)
+ return e, s, clock
+}
+
+func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) {
+ t.Helper()
+
+ igmpEnabled := v4 && mgpEnabled
+ mldEnabled := !v4 && mgpEnabled
+
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ IGMP: ipv4.IGMPOptions{
+ Enabled: igmpEnabled,
+ },
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: mldEnabled,
+ },
+ }),
+ },
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4Addr, err)
+ }
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err)
+ }
+
+ return s, clock
+}
+
+// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join
+// when it is created with an IPv6 address.
+//
+// To not interfere with tests, checkInitialIPv6Groups will leave the added
+// address's solicited node multicast group so that the tests can all assume
+// the NIC has not joined any IPv6 groups.
+func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) {
+ t.Helper()
+
+ stats := s.Stats().ICMP.V6.PacketsSent
+
+ reportCounter++
+ if got := stats.MulticastListenerReport.Value(); got != reportCounter {
+ t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC)
+ }
+
+ // Leave the group to not affect the tests. This is fine since we are not
+ // testing DAD or the solicited node address specifically.
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err)
+ }
+ leaveCounter++
+ if got := stats.MulticastListenerDone.Value(); got != leaveCounter {
+ t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+
+ return reportCounter, leaveCounter
+}
+
+// createAndInjectIGMPPacket creates and injects an IGMP packet with the
+// specified fields.
+//
+// Note, the router alert option is not included in this packet.
+//
+// TODO(b/162198658): set the router alert option.
+func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) {
+ buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize)
+
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(len(buf)),
+ TTL: header.IGMPTTL,
+ Protocol: uint8(header.IGMPProtocolNumber),
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv4AllSystems,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ igmp := header.IGMP(buf[header.IPv4MinimumSize:])
+ igmp.SetType(header.IGMPType(igmpType))
+ igmp.SetMaxRespTime(maxRespTime)
+ igmp.SetGroupAddress(groupAddress)
+ igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
+
+ e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// createAndInjectMLDPacket creates and injects an MLD packet with the
+// specified fields.
+//
+// Note, the router alert option is not included in this packet.
+//
+// TODO(b/162198658): set the router alert option.
+func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) {
+ icmpSize := header.ICMPv6HeaderSize + header.MLDMinimumSize
+ buf := buffer.NewView(header.IPv6MinimumSize + icmpSize)
+
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(icmpSize),
+ HopLimit: header.MLDHopLimit,
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ SrcAddr: header.IPv4Any,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ icmp := header.ICMPv6(buf[header.IPv6MinimumSize:])
+ icmp.SetType(header.ICMPv6Type(mldType))
+ mld := header.MLD(icmp.MessageBody())
+ mld.SetMaximumResponseDelay(uint16(maxRespDelay))
+ mld.SetMulticastAddress(groupAddress)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+
+ e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+// TestMGPDisabled tests that the multicast group protocol is not enabled by
+// default.
+func TestMGPDisabled(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ rxQuery func(*channel.Endpoint)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxQuery: func(e *channel.Endpoint) {
+ createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ rxQuery: func(e *channel.Endpoint) {
+ createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */)
+
+ // This NIC may join multicast groups when it is enabled but since MGP is
+ // disabled, no reports should be sent.
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt)
+ }
+
+ // Test joining a specific group explicitly and verify that no reports are
+ // sent.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt)
+ }
+
+ // Inject a general query message. This should only trigger a report to be
+ // sent if the MGP was enabled.
+ test.rxQuery(e)
+ if got := test.receivedQueryStat(s).Value(); got != 1 {
+ t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got)
+ }
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
+ }
+ })
+ }
+}
+
+func TestMGPReceiveCounters(t *testing.T) {
+ tests := []struct {
+ name string
+ headerType uint8
+ maxRespTime byte
+ groupAddress tcpip.Address
+ statCounter func(*stack.Stack) *tcpip.StatCounter
+ rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address)
+ }{
+ {
+ name: "IGMP Membership Query",
+ headerType: igmpMembershipQuery,
+ maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec,
+ groupAddress: header.IPv4Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMPv1 Membership Report",
+ headerType: igmpv1MembershipReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllSystems,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.V1MembershipReport
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMPv2 Membership Report",
+ headerType: igmpv2MembershipReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllSystems,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.V2MembershipReport
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "IGMP Leave Group",
+ headerType: igmpLeaveGroup,
+ maxRespTime: 0,
+ groupAddress: header.IPv4AllRoutersGroup,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.LeaveGroup
+ },
+ rxMGPkt: createAndInjectIGMPPacket,
+ },
+ {
+ name: "MLD Query",
+ headerType: mldQuery,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ {
+ name: "MLD Report",
+ headerType: mldReport,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ {
+ name: "MLD Done",
+ headerType: mldDone,
+ maxRespTime: 0,
+ groupAddress: header.IPv6Any,
+ statCounter: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone
+ },
+ rxMGPkt: createAndInjectMLDPacket,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */)
+
+ test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress)
+ if got := test.statCounter(s).Value(); got != 1 {
+ t.Fatalf("got %s received = %d, want = 1", test.name, got)
+ }
+ })
+ }
+}
+
+// TestMGPJoinGroup tests that when explicitly joining a multicast group, the
+// stack schedules and sends correct Membership Reports.
+func TestMGPJoinGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo)
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ // Test joining a specific address explicitly and verify a Report is sent
+ // immediately.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ reportCounter++
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Verify the second report is sent by the maximum unsolicited response
+ // interval.
+ p, ok := e.Read()
+ if ok {
+ t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt)
+ }
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPLeaveGroup tests that when leaving a previously joined multicast
+// group the stack sends a leave/done message.
+func TestMGPLeaveGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo)
+ validateLeave func(*testing.T, channel.PacketInfo)
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1)
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ reportCounter++
+ if got := test.sentReportStat(s).Value(); got != reportCounter {
+ t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving the group should trigger an leave/done message to be sent.
+ if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
+ }
+ leaveCounter++
+ if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a leave message to be sent")
+ } else {
+ test.validateLeave(t, p)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPQueryMessages tests that a report is sent in response to query
+// messages.
+func TestMGPQueryMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
+ rxQuery func(*channel.Endpoint, uint8, tcpip.Address)
+ validateReport func(*testing.T, channel.PacketInfo)
+ maxRespTimeToDuration func(uint8) time.Duration
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsReceived.MembershipQuery
+ },
+ rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
+ createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ maxRespTimeToDuration: header.DecisecondToDuration,
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
+ },
+ rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
+ createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ maxRespTimeToDuration: func(d uint8) time.Duration {
+ return time.Duration(d) * time.Millisecond
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ subTests := []struct {
+ name string
+ multicastAddr tcpip.Address
+ expectReport bool
+ }{
+ {
+ name: "Unspecified",
+ multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))),
+ expectReport: true,
+ },
+ {
+ name: "Specified",
+ multicastAddr: test.multicastAddr,
+ expectReport: true,
+ },
+ {
+ name: "Specified other address",
+ multicastAddr: func() tcpip.Address {
+ addrBytes := []byte(test.multicastAddr)
+ addrBytes[len(addrBytes)-1]++
+ return tcpip.Address(addrBytes)
+ }(),
+ expectReport: false,
+ },
+ }
+
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ sentReportStat := test.sentReportStat(s)
+ for i := 0; i < maxUnsolicitedReports; i++ {
+ sentReportStat := test.sentReportStat(s)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatalf("expected %d-th report message to be sent", i)
+ } else {
+ test.validateReport(t, p)
+ }
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Should not send any more packets until a query.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+
+ // Receive a query message which should trigger a report to be sent at
+ // some time before the maximum response time if the report is
+ // targeted at the host.
+ const maxRespTime = 100
+ test.rxQuery(e, maxRespTime, subTest.multicastAddr)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p.Pkt)
+ }
+
+ if subTest.expectReport {
+ clock.Advance(test.maxRespTimeToDuration(maxRespTime))
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestMGPQueryMessages tests that no further reports or leave/done messages
+// are sent after receiving a report.
+func TestMGPReportMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ rxReport func(*channel.Endpoint)
+ validateReport func(*testing.T, channel.PacketInfo)
+ maxRespTimeToDuration func(uint8) time.Duration
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ rxReport: func(e *channel.Endpoint) {
+ createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
+ },
+ maxRespTimeToDuration: header.DecisecondToDuration,
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ rxReport: func(e *channel.Endpoint) {
+ createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1)
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo) {
+ t.Helper()
+
+ validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
+ },
+ maxRespTimeToDuration: func(d uint8) time.Duration {
+ return time.Duration(d) * time.Millisecond
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ sentReportStat := test.sentReportStat(s)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Receiving a report for a group we joined should cancel any further
+ // reports.
+ test.rxReport(e)
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Errorf("sent unexpected packet = %#v", p)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leaving a group after getting a report should not send a leave/done
+ // message.
+ if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
+ }
+ clock.Advance(time.Hour)
+ if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
+ t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+func TestMGPWithNICLifecycle(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddrs []tcpip.Address
+ finalMulticastAddr tcpip.Address
+ maxUnsolicitedResponseDelay time.Duration
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
+ validateReport func(*testing.T, channel.PacketInfo, tcpip.Address)
+ validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address)
+ getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address
+ checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2},
+ finalMulticastAddr: ipv4MulticastAddr3,
+ maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.LeaveGroup
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr)
+ },
+ getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
+ t.Helper()
+
+ ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
+ if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber {
+ t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber)
+ }
+ addr := header.IGMP(ipv4.Payload()).GroupAddress()
+ s, ok := seen[addr]
+ if !ok {
+ t.Fatalf("unexpectedly got a packet for group %s", addr)
+ }
+ if s {
+ t.Fatalf("already saw packet for group %s", addr)
+ }
+ seen[addr] = true
+ return addr
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2},
+ finalMulticastAddr: ipv6MulticastAddr3,
+ maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
+ },
+ validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateMLDPacket(t, p, addr, mldReport, 0, addr)
+ },
+ validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
+ t.Helper()
+
+ validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr)
+ },
+ getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
+ t.Helper()
+
+ ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
+
+ ipv6HeaderIter := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
+ buffer.View(ipv6.Payload()).ToVectorisedView(),
+ )
+
+ var transport header.IPv6RawPayloadHeader
+ for {
+ h, done, err := ipv6HeaderIter.Next()
+ if err != nil {
+ t.Fatalf("ipv6HeaderIter.Next(): %s", err)
+ }
+ if done {
+ t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done)
+ }
+ if t, ok := h.(header.IPv6RawPayloadHeader); ok {
+ transport = t
+ break
+ }
+ }
+
+ if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber {
+ t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber)
+ }
+ icmpv6 := header.ICMPv6(transport.Buf.ToView())
+ if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone {
+ t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone)
+ }
+ addr := header.MLD(icmpv6.MessageBody()).MulticastAddress()
+ s, ok := seen[addr]
+ if !ok {
+ t.Fatalf("unexpectedly got a packet for group %s", addr)
+ }
+ if s {
+ t.Fatalf("already saw packet for group %s", addr)
+ }
+ seen[addr] = true
+ return addr
+ },
+ checkInitialGroups: checkInitialIPv6Groups,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
+
+ var reportCounter uint64
+ var leaveCounter uint64
+ if test.checkInitialGroups != nil {
+ reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
+ }
+
+ sentReportStat := test.sentReportStat(s)
+ for _, a := range test.multicastAddrs {
+ if err := s.JoinGroup(test.protoNum, nicID, a); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err)
+ }
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatalf("expected a report message to be sent for %s", a)
+ } else {
+ test.validateReport(t, p, a)
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Leave messages should be sent for the joined groups when the NIC is
+ // disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ sentLeaveStat := test.sentLeaveStat(s)
+ leaveCounter += uint64(len(test.multicastAddrs))
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ {
+ seen := make(map[tcpip.Address]bool)
+ for _, a := range test.multicastAddrs {
+ seen[a] = false
+ }
+
+ for i := range test.multicastAddrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected (%d-th) leave message to be sent", i)
+ }
+
+ test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p))
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Reports should be sent for the joined groups when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+ reportCounter += uint64(len(test.multicastAddrs))
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ {
+ seen := make(map[tcpip.Address]bool)
+ for _, a := range test.multicastAddrs {
+ seen[a] = false
+ }
+
+ for i := range test.multicastAddrs {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected (%d-th) report message to be sent", i)
+ }
+
+ test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p))
+ }
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Joining/leaving a group while disabled should not send any messages.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ leaveCounter += uint64(len(test.multicastAddrs))
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ for i := range test.multicastAddrs {
+ if _, ok := e.Read(); !ok {
+ t.Fatalf("expected (%d-th) leave message to be sent", i)
+ }
+ }
+ for _, a := range test.multicastAddrs {
+ if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil {
+ t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err)
+ }
+ if got := sentLeaveStat.Value(); got != leaveCounter {
+ t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt)
+ }
+ }
+ if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); ok {
+ t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt)
+ }
+
+ // A report should only be sent for the group we last joined after
+ // enabling the NIC since the original groups were all left.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p, test.finalMulticastAddr)
+ }
+
+ clock.Advance(test.maxUnsolicitedResponseDelay)
+ reportCounter++
+ if got := sentReportStat.Value(); got != reportCounter {
+ t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ test.validateReport(t, p, test.finalMulticastAddr)
+ }
+
+ // Should not send any more packets.
+ clock.Advance(time.Hour)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("sent unexpected packet = %#v", p)
+ }
+ })
+ }
+}
+
+// TestMGPDisabledOnLoopback tests that the multicast group protocol is not
+// performed on loopback interfaces since they have no neighbours.
+func TestMGPDisabledOnLoopback(t *testing.T) {
+ tests := []struct {
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ multicastAddr tcpip.Address
+ sentReportStat func(*stack.Stack) *tcpip.StatCounter
+ }{
+ {
+ name: "IGMP",
+ protoNum: ipv4.ProtocolNumber,
+ multicastAddr: ipv4MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().IGMP.PacketsSent.V2MembershipReport
+ },
+ },
+ {
+ name: "MLD",
+ protoNum: ipv6.ProtocolNumber,
+ multicastAddr: ipv6MulticastAddr1,
+ sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
+ return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New())
+
+ sentReportStat := test.sentReportStat(s)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+
+ // Test joining a specific group explicitly and verify that no reports are
+ // sent.
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
+ t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
+ }
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ clock.Advance(time.Hour)
+ if got := sentReportStat.Value(); got != 0 {
+ t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
index 7cc52985e..5c3363759 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -85,21 +85,6 @@ func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts st
return n, nil
}
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- if ep.allowPackets == 0 {
- return ep.err
- }
- ep.allowPackets--
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- ep.WrittenPackets = append(ep.WrittenPackets, pkt)
-
- return nil
-}
-
// Attach implements LinkEndpoint.Attach.
func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 51d428049..4777163cd 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -44,6 +44,7 @@ import (
"bufio"
"fmt"
"log"
+ "math"
"math/rand"
"net"
"os"
@@ -200,7 +201,7 @@ func main() {
// connection from its side.
wq.EventRegister(&waitEntry, waiter.EventIn)
for {
- v, _, err := ep.Read(nil)
+ _, err := ep.Read(os.Stdout, math.MaxUint16, tcpip.ReadOptions{})
if err != nil {
if err == tcpip.ErrClosedForReceive {
break
@@ -213,8 +214,6 @@ func main() {
log.Fatal("Read() failed:", err)
}
-
- os.Stdout.Write(v)
}
wq.EventUnregister(&waitEntry)
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 8e0ee1cd7..a80fa0474 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -20,8 +20,10 @@
package main
import (
+ "bytes"
"flag"
"log"
+ "math"
"math/rand"
"net"
"os"
@@ -54,7 +56,8 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
defer wq.EventUnregister(&waitEntry)
for {
- v, _, err := ep.Read(nil)
+ var buf bytes.Buffer
+ _, err := ep.Read(&buf, math.MaxUint16, tcpip.ReadOptions{})
if err != nil {
if err == tcpip.ErrWouldBlock {
<-notifyCh
@@ -64,7 +67,7 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
return
}
- ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ ep.Write(tcpip.SlicePayload(buf.Bytes()), tcpip.WriteOptions{})
}
}
@@ -148,10 +151,6 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- log.Fatal(err)
- }
-
subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
if err != nil {
log.Fatal(err)
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
new file mode 100644
index 000000000..f3ad40fdf
--- /dev/null
+++ b/pkg/tcpip/socketops.go
@@ -0,0 +1,520 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// SocketOptionsHandler holds methods that help define endpoint specific
+// behavior for socket level socket options. These must be implemented by
+// endpoints to get notified when socket level options are set.
+type SocketOptionsHandler interface {
+ // OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint.
+ OnReuseAddressSet(v bool)
+
+ // OnReusePortSet is invoked when SO_REUSEPORT is set for an endpoint.
+ OnReusePortSet(v bool)
+
+ // OnKeepAliveSet is invoked when SO_KEEPALIVE is set for an endpoint.
+ OnKeepAliveSet(v bool)
+
+ // OnDelayOptionSet is invoked when TCP_NODELAY is set for an endpoint.
+ // Note that v will be the inverse of TCP_NODELAY option.
+ OnDelayOptionSet(v bool)
+
+ // OnCorkOptionSet is invoked when TCP_CORK is set for an endpoint.
+ OnCorkOptionSet(v bool)
+
+ // LastError is invoked when SO_ERROR is read for an endpoint.
+ LastError() *Error
+
+ // UpdateLastError updates the endpoint specific last error field.
+ UpdateLastError(err *Error)
+
+ // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE.
+ HasNIC(v int32) bool
+}
+
+// DefaultSocketOptionsHandler is an embeddable type that implements no-op
+// implementations for SocketOptionsHandler methods.
+type DefaultSocketOptionsHandler struct{}
+
+var _ SocketOptionsHandler = (*DefaultSocketOptionsHandler)(nil)
+
+// OnReuseAddressSet implements SocketOptionsHandler.OnReuseAddressSet.
+func (*DefaultSocketOptionsHandler) OnReuseAddressSet(bool) {}
+
+// OnReusePortSet implements SocketOptionsHandler.OnReusePortSet.
+func (*DefaultSocketOptionsHandler) OnReusePortSet(bool) {}
+
+// OnKeepAliveSet implements SocketOptionsHandler.OnKeepAliveSet.
+func (*DefaultSocketOptionsHandler) OnKeepAliveSet(bool) {}
+
+// OnDelayOptionSet implements SocketOptionsHandler.OnDelayOptionSet.
+func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {}
+
+// OnCorkOptionSet implements SocketOptionsHandler.OnCorkOptionSet.
+func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {}
+
+// LastError implements SocketOptionsHandler.LastError.
+func (*DefaultSocketOptionsHandler) LastError() *Error {
+ return nil
+}
+
+// UpdateLastError implements SocketOptionsHandler.UpdateLastError.
+func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {}
+
+// HasNIC implements SocketOptionsHandler.HasNIC.
+func (*DefaultSocketOptionsHandler) HasNIC(int32) bool {
+ return false
+}
+
+// SocketOptions contains all the variables which store values for SOL_SOCKET,
+// SOL_IP, SOL_IPV6 and SOL_TCP level options.
+//
+// +stateify savable
+type SocketOptions struct {
+ handler SocketOptionsHandler
+
+ // These fields are accessed and modified using atomic operations.
+
+ // broadcastEnabled determines whether datagram sockets are allowed to
+ // send packets to a broadcast address.
+ broadcastEnabled uint32
+
+ // passCredEnabled determines whether SCM_CREDENTIALS socket control
+ // messages are enabled.
+ passCredEnabled uint32
+
+ // noChecksumEnabled determines whether UDP checksum is disabled while
+ // transmitting for this socket.
+ noChecksumEnabled uint32
+
+ // reuseAddressEnabled determines whether Bind() should allow reuse of
+ // local address.
+ reuseAddressEnabled uint32
+
+ // reusePortEnabled determines whether to permit multiple sockets to be
+ // bound to an identical socket address.
+ reusePortEnabled uint32
+
+ // keepAliveEnabled determines whether TCP keepalive is enabled for this
+ // socket.
+ keepAliveEnabled uint32
+
+ // multicastLoopEnabled determines whether multicast packets sent over a
+ // non-loopback interface will be looped back.
+ multicastLoopEnabled uint32
+
+ // receiveTOSEnabled is used to specify if the TOS ancillary message is
+ // passed with incoming packets.
+ receiveTOSEnabled uint32
+
+ // receiveTClassEnabled is used to specify if the IPV6_TCLASS ancillary
+ // message is passed with incoming packets.
+ receiveTClassEnabled uint32
+
+ // receivePacketInfoEnabled is used to specify if more inforamtion is
+ // provided with incoming packets such as interface index and address.
+ receivePacketInfoEnabled uint32
+
+ // hdrIncludeEnabled is used to indicate for a raw endpoint that all packets
+ // being written have an IP header and the endpoint should not attach an IP
+ // header.
+ hdrIncludedEnabled uint32
+
+ // v6OnlyEnabled is used to determine whether an IPv6 socket is to be
+ // restricted to sending and receiving IPv6 packets only.
+ v6OnlyEnabled uint32
+
+ // quickAckEnabled is used to represent the value of TCP_QUICKACK option.
+ // It currently does not have any effect on the TCP endpoint.
+ quickAckEnabled uint32
+
+ // delayOptionEnabled is used to specify if data should be sent out immediately
+ // by the transport protocol. For TCP, it determines if the Nagle algorithm
+ // is on or off.
+ delayOptionEnabled uint32
+
+ // corkOptionEnabled is used to specify if data should be held until segments
+ // are full by the TCP transport protocol.
+ corkOptionEnabled uint32
+
+ // receiveOriginalDstAddress is used to specify if the original destination of
+ // the incoming packet should be returned as an ancillary message.
+ receiveOriginalDstAddress uint32
+
+ // recvErrEnabled determines whether extended reliable error message passing
+ // is enabled.
+ recvErrEnabled uint32
+
+ // errQueue is the per-socket error queue. It is protected by errQueueMu.
+ errQueueMu sync.Mutex `state:"nosave"`
+ errQueue sockErrorList
+
+ // bindToDevice determines the device to which the socket is bound.
+ bindToDevice int32
+
+ // mu protects the access to the below fields.
+ mu sync.Mutex `state:"nosave"`
+
+ // linger determines the amount of time the socket should linger before
+ // close. We currently implement this option for TCP socket only.
+ linger LingerOption
+}
+
+// InitHandler initializes the handler. This must be called before using the
+// socket options utility.
+func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) {
+ so.handler = handler
+}
+
+func storeAtomicBool(addr *uint32, v bool) {
+ var val uint32
+ if v {
+ val = 1
+ }
+ atomic.StoreUint32(addr, val)
+}
+
+// SetLastError sets the last error for a socket.
+func (so *SocketOptions) SetLastError(err *Error) {
+ so.handler.UpdateLastError(err)
+}
+
+// GetBroadcast gets value for SO_BROADCAST option.
+func (so *SocketOptions) GetBroadcast() bool {
+ return atomic.LoadUint32(&so.broadcastEnabled) != 0
+}
+
+// SetBroadcast sets value for SO_BROADCAST option.
+func (so *SocketOptions) SetBroadcast(v bool) {
+ storeAtomicBool(&so.broadcastEnabled, v)
+}
+
+// GetPassCred gets value for SO_PASSCRED option.
+func (so *SocketOptions) GetPassCred() bool {
+ return atomic.LoadUint32(&so.passCredEnabled) != 0
+}
+
+// SetPassCred sets value for SO_PASSCRED option.
+func (so *SocketOptions) SetPassCred(v bool) {
+ storeAtomicBool(&so.passCredEnabled, v)
+}
+
+// GetNoChecksum gets value for SO_NO_CHECK option.
+func (so *SocketOptions) GetNoChecksum() bool {
+ return atomic.LoadUint32(&so.noChecksumEnabled) != 0
+}
+
+// SetNoChecksum sets value for SO_NO_CHECK option.
+func (so *SocketOptions) SetNoChecksum(v bool) {
+ storeAtomicBool(&so.noChecksumEnabled, v)
+}
+
+// GetReuseAddress gets value for SO_REUSEADDR option.
+func (so *SocketOptions) GetReuseAddress() bool {
+ return atomic.LoadUint32(&so.reuseAddressEnabled) != 0
+}
+
+// SetReuseAddress sets value for SO_REUSEADDR option.
+func (so *SocketOptions) SetReuseAddress(v bool) {
+ storeAtomicBool(&so.reuseAddressEnabled, v)
+ so.handler.OnReuseAddressSet(v)
+}
+
+// GetReusePort gets value for SO_REUSEPORT option.
+func (so *SocketOptions) GetReusePort() bool {
+ return atomic.LoadUint32(&so.reusePortEnabled) != 0
+}
+
+// SetReusePort sets value for SO_REUSEPORT option.
+func (so *SocketOptions) SetReusePort(v bool) {
+ storeAtomicBool(&so.reusePortEnabled, v)
+ so.handler.OnReusePortSet(v)
+}
+
+// GetKeepAlive gets value for SO_KEEPALIVE option.
+func (so *SocketOptions) GetKeepAlive() bool {
+ return atomic.LoadUint32(&so.keepAliveEnabled) != 0
+}
+
+// SetKeepAlive sets value for SO_KEEPALIVE option.
+func (so *SocketOptions) SetKeepAlive(v bool) {
+ storeAtomicBool(&so.keepAliveEnabled, v)
+ so.handler.OnKeepAliveSet(v)
+}
+
+// GetMulticastLoop gets value for IP_MULTICAST_LOOP option.
+func (so *SocketOptions) GetMulticastLoop() bool {
+ return atomic.LoadUint32(&so.multicastLoopEnabled) != 0
+}
+
+// SetMulticastLoop sets value for IP_MULTICAST_LOOP option.
+func (so *SocketOptions) SetMulticastLoop(v bool) {
+ storeAtomicBool(&so.multicastLoopEnabled, v)
+}
+
+// GetReceiveTOS gets value for IP_RECVTOS option.
+func (so *SocketOptions) GetReceiveTOS() bool {
+ return atomic.LoadUint32(&so.receiveTOSEnabled) != 0
+}
+
+// SetReceiveTOS sets value for IP_RECVTOS option.
+func (so *SocketOptions) SetReceiveTOS(v bool) {
+ storeAtomicBool(&so.receiveTOSEnabled, v)
+}
+
+// GetReceiveTClass gets value for IPV6_RECVTCLASS option.
+func (so *SocketOptions) GetReceiveTClass() bool {
+ return atomic.LoadUint32(&so.receiveTClassEnabled) != 0
+}
+
+// SetReceiveTClass sets value for IPV6_RECVTCLASS option.
+func (so *SocketOptions) SetReceiveTClass(v bool) {
+ storeAtomicBool(&so.receiveTClassEnabled, v)
+}
+
+// GetReceivePacketInfo gets value for IP_PKTINFO option.
+func (so *SocketOptions) GetReceivePacketInfo() bool {
+ return atomic.LoadUint32(&so.receivePacketInfoEnabled) != 0
+}
+
+// SetReceivePacketInfo sets value for IP_PKTINFO option.
+func (so *SocketOptions) SetReceivePacketInfo(v bool) {
+ storeAtomicBool(&so.receivePacketInfoEnabled, v)
+}
+
+// GetHeaderIncluded gets value for IP_HDRINCL option.
+func (so *SocketOptions) GetHeaderIncluded() bool {
+ return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0
+}
+
+// SetHeaderIncluded sets value for IP_HDRINCL option.
+func (so *SocketOptions) SetHeaderIncluded(v bool) {
+ storeAtomicBool(&so.hdrIncludedEnabled, v)
+}
+
+// GetV6Only gets value for IPV6_V6ONLY option.
+func (so *SocketOptions) GetV6Only() bool {
+ return atomic.LoadUint32(&so.v6OnlyEnabled) != 0
+}
+
+// SetV6Only sets value for IPV6_V6ONLY option.
+//
+// Preconditions: the backing TCP or UDP endpoint must be in initial state.
+func (so *SocketOptions) SetV6Only(v bool) {
+ storeAtomicBool(&so.v6OnlyEnabled, v)
+}
+
+// GetQuickAck gets value for TCP_QUICKACK option.
+func (so *SocketOptions) GetQuickAck() bool {
+ return atomic.LoadUint32(&so.quickAckEnabled) != 0
+}
+
+// SetQuickAck sets value for TCP_QUICKACK option.
+func (so *SocketOptions) SetQuickAck(v bool) {
+ storeAtomicBool(&so.quickAckEnabled, v)
+}
+
+// GetDelayOption gets inverted value for TCP_NODELAY option.
+func (so *SocketOptions) GetDelayOption() bool {
+ return atomic.LoadUint32(&so.delayOptionEnabled) != 0
+}
+
+// SetDelayOption sets inverted value for TCP_NODELAY option.
+func (so *SocketOptions) SetDelayOption(v bool) {
+ storeAtomicBool(&so.delayOptionEnabled, v)
+ so.handler.OnDelayOptionSet(v)
+}
+
+// GetCorkOption gets value for TCP_CORK option.
+func (so *SocketOptions) GetCorkOption() bool {
+ return atomic.LoadUint32(&so.corkOptionEnabled) != 0
+}
+
+// SetCorkOption sets value for TCP_CORK option.
+func (so *SocketOptions) SetCorkOption(v bool) {
+ storeAtomicBool(&so.corkOptionEnabled, v)
+ so.handler.OnCorkOptionSet(v)
+}
+
+// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option.
+func (so *SocketOptions) GetReceiveOriginalDstAddress() bool {
+ return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0
+}
+
+// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option.
+func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) {
+ storeAtomicBool(&so.receiveOriginalDstAddress, v)
+}
+
+// GetRecvError gets value for IP*_RECVERR option.
+func (so *SocketOptions) GetRecvError() bool {
+ return atomic.LoadUint32(&so.recvErrEnabled) != 0
+}
+
+// SetRecvError sets value for IP*_RECVERR option.
+func (so *SocketOptions) SetRecvError(v bool) {
+ storeAtomicBool(&so.recvErrEnabled, v)
+ if !v {
+ so.pruneErrQueue()
+ }
+}
+
+// GetLastError gets value for SO_ERROR option.
+func (so *SocketOptions) GetLastError() *Error {
+ return so.handler.LastError()
+}
+
+// GetOutOfBandInline gets value for SO_OOBINLINE option.
+func (*SocketOptions) GetOutOfBandInline() bool {
+ return true
+}
+
+// SetOutOfBandInline sets value for SO_OOBINLINE option. We currently do not
+// support disabling this option.
+func (*SocketOptions) SetOutOfBandInline(bool) {}
+
+// GetLinger gets value for SO_LINGER option.
+func (so *SocketOptions) GetLinger() LingerOption {
+ so.mu.Lock()
+ linger := so.linger
+ so.mu.Unlock()
+ return linger
+}
+
+// SetLinger sets value for SO_LINGER option.
+func (so *SocketOptions) SetLinger(linger LingerOption) {
+ so.mu.Lock()
+ so.linger = linger
+ so.mu.Unlock()
+}
+
+// SockErrOrigin represents the constants for error origin.
+type SockErrOrigin uint8
+
+const (
+ // SockExtErrorOriginNone represents an unknown error origin.
+ SockExtErrorOriginNone SockErrOrigin = iota
+
+ // SockExtErrorOriginLocal indicates a local error.
+ SockExtErrorOriginLocal
+
+ // SockExtErrorOriginICMP indicates an IPv4 ICMP error.
+ SockExtErrorOriginICMP
+
+ // SockExtErrorOriginICMP6 indicates an IPv6 ICMP error.
+ SockExtErrorOriginICMP6
+)
+
+// IsICMPErr indicates if the error originated from an ICMP error.
+func (origin SockErrOrigin) IsICMPErr() bool {
+ return origin == SockExtErrorOriginICMP || origin == SockExtErrorOriginICMP6
+}
+
+// SockError represents a queue entry in the per-socket error queue.
+//
+// +stateify savable
+type SockError struct {
+ sockErrorEntry
+
+ // Err is the error caused by the errant packet.
+ Err *Error
+ // ErrOrigin indicates the error origin.
+ ErrOrigin SockErrOrigin
+ // ErrType is the type in the ICMP header.
+ ErrType uint8
+ // ErrCode is the code in the ICMP header.
+ ErrCode uint8
+ // ErrInfo is additional info about the error.
+ ErrInfo uint32
+
+ // Payload is the errant packet's payload.
+ Payload []byte
+ // Dst is the original destination address of the errant packet.
+ Dst FullAddress
+ // Offender is the original sender address of the errant packet.
+ Offender FullAddress
+ // NetProto is the network protocol being used to transmit the packet.
+ NetProto NetworkProtocolNumber
+}
+
+// pruneErrQueue resets the queue.
+func (so *SocketOptions) pruneErrQueue() {
+ so.errQueueMu.Lock()
+ so.errQueue.Reset()
+ so.errQueueMu.Unlock()
+}
+
+// DequeueErr dequeues a socket extended error from the error queue and returns
+// it. Returns nil if queue is empty.
+func (so *SocketOptions) DequeueErr() *SockError {
+ so.errQueueMu.Lock()
+ defer so.errQueueMu.Unlock()
+
+ err := so.errQueue.Front()
+ if err != nil {
+ so.errQueue.Remove(err)
+ }
+ return err
+}
+
+// PeekErr returns the error in the front of the error queue. Returns nil if
+// the error queue is empty.
+func (so *SocketOptions) PeekErr() *SockError {
+ so.errQueueMu.Lock()
+ defer so.errQueueMu.Unlock()
+ return so.errQueue.Front()
+}
+
+// QueueErr inserts the error at the back of the error queue.
+//
+// Preconditions: so.GetRecvError() == true.
+func (so *SocketOptions) QueueErr(err *SockError) {
+ so.errQueueMu.Lock()
+ defer so.errQueueMu.Unlock()
+ so.errQueue.PushBack(err)
+}
+
+// QueueLocalErr queues a local error onto the local queue.
+func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) {
+ so.QueueErr(&SockError{
+ Err: err,
+ ErrOrigin: SockExtErrorOriginLocal,
+ ErrInfo: info,
+ Payload: payload,
+ Dst: dst,
+ NetProto: net,
+ })
+}
+
+// GetBindToDevice gets value for SO_BINDTODEVICE option.
+func (so *SocketOptions) GetBindToDevice() int32 {
+ return atomic.LoadInt32(&so.bindToDevice)
+}
+
+// SetBindToDevice sets value for SO_BINDTODEVICE option.
+func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error {
+ if !so.handler.HasNIC(bindToDevice) {
+ return ErrUnknownDevice
+ }
+
+ atomic.StoreInt32(&so.bindToDevice, bindToDevice)
+ return nil
+}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index d09ebe7fa..bb30556cf 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "most_shards")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -112,7 +112,7 @@ go_test(
"transport_demuxer_test.go",
"transport_test.go",
],
- shard_count = 20,
+ shard_count = most_shards,
deps = [
":stack",
"//pkg/rand",
@@ -120,6 +120,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
@@ -131,7 +132,6 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
- "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
@@ -148,7 +148,6 @@ go_test(
],
library = ":stack",
deps = [
- "//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 9478f3fb7..cd423bf71 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil)
var _ AddressableEndpoint = (*AddressableEndpointState)(nil)
// AddressableEndpointState is an implementation of an AddressableEndpoint.
@@ -37,10 +36,6 @@ type AddressableEndpointState struct {
endpoints map[tcpip.Address]*addressState
primary []*addressState
-
- // groups holds the mapping between group addresses and the number of times
- // they have been joined.
- groups map[tcpip.Address]uint32
}
}
@@ -53,65 +48,33 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) {
a.mu.Lock()
defer a.mu.Unlock()
a.mu.endpoints = make(map[tcpip.Address]*addressState)
- a.mu.groups = make(map[tcpip.Address]uint32)
-}
-
-// ReadOnlyAddressableEndpointState provides read-only access to an
-// AddressableEndpointState.
-type ReadOnlyAddressableEndpointState struct {
- inner *AddressableEndpointState
}
-// AddrOrMatching returns an endpoint for the passed address that is consisdered
-// bound to the wrapped AddressableEndpointState.
+// GetAddress returns the AddressEndpoint for the passed address.
//
-// If addr is an exact match with an existing address, that address is returned.
-// Otherwise, f is called with each address and the address that f returns true
-// for is returned.
-//
-// Returns nil of no address matches.
-func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
-
- if ep, ok := m.inner.mu.endpoints[addr]; ok {
- if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() {
- return ep
- }
- }
-
- for _, ep := range m.inner.mu.endpoints {
- if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() {
- return ep
- }
- }
-
- return nil
-}
-
-// Lookup returns the AddressEndpoint for the passed address.
+// GetAddress does not increment the address's reference count or check if the
+// address is considered bound to the endpoint.
//
-// Returns nil if the passed address is not associated with the
-// AddressableEndpointState.
-func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Returns nil if the passed address is not associated with the endpoint.
+func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- ep, ok := m.inner.mu.endpoints[addr]
+ ep, ok := a.mu.endpoints[addr]
if !ok {
return nil
}
return ep
}
-// ForEach calls f for each address pair.
+// ForEachEndpoint calls f for each address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- for _, ep := range m.inner.mu.endpoints {
+ for _, ep := range a.mu.endpoints {
if !f(ep) {
return
}
@@ -120,18 +83,16 @@ func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool)
// ForEachPrimaryEndpoint calls f for each primary address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
- for _, ep := range m.inner.mu.primary {
- f(ep)
- }
-}
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
-// ReadOnly returns a readonly reference to a.
-func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState {
- return ReadOnlyAddressableEndpointState{inner: a}
+ for _, ep := range a.mu.primary {
+ if !f(ep) {
+ return
+ }
+ }
}
func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) {
@@ -335,11 +296,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
a.mu.Lock()
defer a.mu.Unlock()
-
- if _, ok := a.mu.groups[addr]; ok {
- panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr))
- }
-
return a.removePermanentAddressLocked(addr)
}
@@ -471,8 +427,19 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad
return deprecatedEndpoint
}
-// AcquireAssignedAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+// AcquireAssignedAddressOrMatching returns an address endpoint that is
+// considered assigned to the addressable endpoint.
+//
+// If the address is an exact match with an existing address, that address is
+// returned. Otherwise, if f is provided, f is called with each address and
+// the address that f returns true for is returned.
+//
+// If there is no matching address, a temporary address will be returned if
+// allowTemp is true.
+//
+// Regardless how the address was obtained, it will be acquired before it is
+// returned.
+func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
a.mu.Lock()
defer a.mu.Unlock()
@@ -488,6 +455,14 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return addrState
}
+ if f != nil {
+ for _, addrState := range a.mu.endpoints {
+ if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() {
+ return addrState
+ }
+ }
+ }
+
if !allowTemp {
return nil
}
@@ -520,6 +495,11 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return ep
}
+// AcquireAssignedAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+ return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB)
+}
+
// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
a.mu.RLock()
@@ -588,72 +568,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi
return addrs
}
-// JoinGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */)
- if err != nil {
- return false, err
- }
- // We have no need for the address endpoint.
- a.decAddressRefLocked(ep)
- }
-
- a.mu.groups[group] = joins + 1
- return !ok, nil
-}
-
-// LeaveGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- return false, tcpip.ErrBadLocalAddress
- }
-
- if joins == 1 {
- a.removeGroupAddressLocked(group)
- delete(a.mu.groups, group)
- return true, nil
- }
-
- a.mu.groups[group] = joins - 1
- return false, nil
-}
-
-// IsInGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
- a.mu.RLock()
- defer a.mu.RUnlock()
- _, ok := a.mu.groups[group]
- return ok
-}
-
-func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) {
- if err := a.removePermanentAddressLocked(group); err != nil {
- // removePermanentEndpointLocked would only return an error if group is
- // not bound to the addressable endpoint, but we know it MUST be assigned
- // since we have group in our map of groups.
- panic(fmt.Sprintf("error removing group address = %s: %s", group, err))
- }
-}
-
// Cleanup forcefully leaves all groups and removes all permanent addresses.
func (a *AddressableEndpointState) Cleanup() {
a.mu.Lock()
defer a.mu.Unlock()
- for group := range a.mu.groups {
- a.removeGroupAddressLocked(group)
- }
- a.mu.groups = make(map[tcpip.Address]uint32)
-
for _, ep := range a.mu.endpoints {
// removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
// not a permanent address.
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 26787d0a3..140f146f6 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -53,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
ep.DecRef()
}
- group := tcpip.Address("\x02")
- if added, err := s.JoinGroup(group); err != nil {
- t.Fatalf("s.JoinGroup(%s): %s", group, err)
- } else if !added {
- t.Fatalf("got s.JoinGroup(%s) = false, want = true", group)
- }
- if !s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = false, want = true", group)
- }
-
s.Cleanup()
- {
- ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
- if ep != nil {
- ep.DecRef()
- t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
- }
- }
- if s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = true, want = false", group)
+ if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil {
+ ep.DecRef()
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
}
}
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 9a17efcba..5e649cca6 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -142,19 +142,19 @@ func (cn *conn) timedOut(now time.Time) bool {
// update the connection tracking state.
//
-// Precondition: ct.mu must be held.
-func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+// Precondition: cn.mu must be held.
+func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
// TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
// other tcp states.
- if ct.tcb.IsEmpty() {
- ct.tcb.Init(tcpHeader)
- } else if hook == ct.tcbHook {
- ct.tcb.UpdateStateOutbound(tcpHeader)
+ if cn.tcb.IsEmpty() {
+ cn.tcb.Init(tcpHeader)
+ } else if hook == cn.tcbHook {
+ cn.tcb.UpdateStateOutbound(tcpHeader)
} else {
- ct.tcb.UpdateStateInbound(tcpHeader)
+ cn.tcb.UpdateStateInbound(tcpHeader)
}
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 7a501acdc..93e8e1c51 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -74,8 +74,30 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
}
func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
- // Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ netHdr := pkt.NetworkHeader().View()
+ _, dst := f.proto.ParseAddresses(netHdr)
+
+ addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint)
+ if addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ // Dispatch the packet to the transport protocol.
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt)
+ return
+ }
+
+ r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */)
+ if err != nil {
+ return
+ }
+ defer r.Release()
+
+ vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ pkt = NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: vv.ToView().ToVectorisedView(),
+ })
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ _ = r.WriteHeaderIncludedPacket(pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -106,8 +128,13 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf
panic("not implemented")
}
-func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
+ // The network header should not already be populated.
+ if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok {
+ return tcpip.ErrMalformedHeader
+ }
+
+ return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt)
}
func (f *fwdTestNetworkEndpoint) Close() {
@@ -117,6 +144,8 @@ func (f *fwdTestNetworkEndpoint) Close() {
// fwdTestNetworkProtocol is a network-layer protocol that implements Address
// resolution.
type fwdTestNetworkProtocol struct {
+ stack *Stack
+
addrCache *linkAddrCache
neigh *neighborCache
addrResolveDelay time.Duration
@@ -280,7 +309,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
p := fwdTestPacketInfo{
- RemoteLinkAddress: r.RemoteLinkAddress,
+ RemoteLinkAddress: r.RemoteLinkAddress(),
LocalLinkAddress: r.LocalLinkAddress,
Pkt: pkt,
}
@@ -304,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer
return n, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- p := fwdTestPacketInfo{
- Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}),
- }
-
- select {
- case e.C <- p:
- default:
- }
-
- return nil
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
@@ -334,7 +349,10 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
- NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }},
+ NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol {
+ proto.stack = s
+ return proto
+ }},
UseNeighborCache: useNeighborCache,
})
@@ -542,6 +560,38 @@ func TestForwardingWithNoResolver(t *testing.T) {
}
}
+func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) {
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 50 * time.Millisecond,
+ onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) {
+ // Don't resolve the link address.
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */)
+
+ const numPackets int = 5
+ // These packets will all be enqueued in the packet queue to wait for link
+ // address resolution.
+ for i := 0; i < numPackets; i++ {
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
+
+ // All packets should fail resolution.
+ // TODO(gvisor.dev/issue/5141): Use a fake clock.
+ for i := 0; i < numPackets; i++ {
+ select {
+ case got := <-ep2.C:
+ t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got)
+ case <-time.After(100 * time.Millisecond):
+ }
+ }
+}
+
func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
tests := []struct {
name string
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 2d8c883cd..09c7811fa 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -45,13 +45,13 @@ const reaperDelay = 5 * time.Second
func DefaultTables() *IPTables {
return &IPTables{
v4Tables: [NumTables]Table{
- NATID: Table{
+ NATID: {
Rules: []Rule{
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -68,11 +68,11 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- MangleID: Table{
+ MangleID: {
Rules: []Rule{
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -86,12 +86,12 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- FilterID: Table{
+ FilterID: {
Rules: []Rule{
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: HookUnset,
@@ -110,13 +110,13 @@ func DefaultTables() *IPTables {
},
},
v6Tables: [NumTables]Table{
- NATID: Table{
+ NATID: {
Rules: []Rule{
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -133,11 +133,11 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- MangleID: Table{
+ MangleID: {
Rules: []Rule{
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -151,12 +151,12 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- FilterID: Table{
+ FilterID: {
Rules: []Rule{
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: HookUnset,
@@ -175,9 +175,9 @@ func DefaultTables() *IPTables {
},
},
priorities: [NumHooks][]TableID{
- Prerouting: []TableID{MangleID, NATID},
- Input: []TableID{NATID, FilterID},
- Output: []TableID{MangleID, NATID, FilterID},
+ Prerouting: {MangleID, NATID},
+ Input: {NATID, FilterID},
+ Output: {MangleID, NATID, FilterID},
},
connections: ConnTrack{
seed: generateRandUint32(),
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 4b86c1be9..56a3e7861 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -56,7 +56,7 @@ const (
// Postrouting happens just before a packet goes out on the wire.
Postrouting
- // The total number of hooks.
+ // NumHooks is the total number of hooks.
NumHooks
)
@@ -273,14 +273,12 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo
return true
}
- // If the interface name ends with '+', any interface which begins
- // with the name should be matched.
+ // If the interface name ends with '+', any interface which
+ // begins with the name should be matched.
ifName := fl.OutputInterface
- matches := true
+ matches := nicName == ifName
if strings.HasSuffix(ifName, "+") {
matches = strings.HasPrefix(nicName, ifName[:n-1])
- } else {
- matches = nicName == ifName
}
return fl.OutputInterfaceInvert != matches
}
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index c9b13cd0e..792f4f170 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -18,7 +18,6 @@ import (
"fmt"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -58,9 +57,6 @@ const (
incomplete entryState = iota
// ready means that the address has been resolved and can be used.
ready
- // failed means that address resolution timed out and the address
- // could not be resolved.
- failed
)
// String implements Stringer.
@@ -70,8 +66,6 @@ func (s entryState) String() string {
return "incomplete"
case ready:
return "ready"
- case failed:
- return "failed"
default:
return fmt.Sprintf("unknown(%d)", s)
}
@@ -80,40 +74,48 @@ func (s entryState) String() string {
// A linkAddrEntry is an entry in the linkAddrCache.
// This struct is thread-compatible.
type linkAddrEntry struct {
+ // linkAddrEntryEntry access is synchronized by the linkAddrCache lock.
linkAddrEntryEntry
+ // TODO(gvisor.dev/issue/5150): move these fields under mu.
+ // mu protects the fields below.
+ mu sync.RWMutex
+
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
- // wakers is a set of waiters for address resolution result. Anytime
- // state transitions out of incomplete these waiters are notified.
- wakers map[*sleep.Waker]struct{}
-
- // done is used to allow callers to wait on address resolution. It is nil iff
- // s is incomplete and resolution is not yet in progress.
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
done chan struct{}
+
+ // onResolve is called with the result of address resolution.
+ onResolve []func(tcpip.LinkAddress, bool)
}
-// changeState sets the entry's state to ns, notifying any waiters.
+func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
+ for _, callback := range e.onResolve {
+ callback(linkAddr, len(linkAddr) != 0)
+ }
+ e.onResolve = nil
+ if ch := e.done; ch != nil {
+ close(ch)
+ e.done = nil
+ }
+}
+
+// changeStateLocked sets the entry's state to ns.
//
// The entry's expiration is bumped up to the greater of itself and the passed
// expiration; the zero value indicates immediate expiration, and is set
// unconditionally - this is an implementation detail that allows for entries
// to be reused.
-func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
- // Notify whoever is waiting on address resolution when transitioning
- // out of incomplete.
- if e.s == incomplete && ns != incomplete {
- for w := range e.wakers {
- w.Assert()
- }
- e.wakers = nil
- if ch := e.done; ch != nil {
- close(ch)
- }
- e.done = nil
+//
+// Precondition: e.mu must be locked
+func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) {
+ if e.s == incomplete && ns == ready {
+ e.notifyCompletionLocked(e.linkAddr)
}
if expiration.IsZero() || expiration.After(e.expiration) {
@@ -122,10 +124,6 @@ func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
e.s = ns
}
-func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
- delete(e.wakers, w)
-}
-
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
// Calculate expiration time before acquiring the lock, since expiration is
@@ -135,10 +133,12 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
c.cache.Lock()
entry := c.getOrCreateEntryLocked(k)
- entry.linkAddr = v
-
- entry.changeState(ready, expiration)
c.cache.Unlock()
+
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+ entry.linkAddr = v
+ entry.changeStateLocked(ready, expiration)
}
// getOrCreateEntryLocked retrieves a cache entry associated with k. The
@@ -159,13 +159,14 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
var entry *linkAddrEntry
if len(c.cache.table) == linkAddrCacheSize {
entry = c.cache.lru.Back()
+ entry.mu.Lock()
delete(c.cache.table, entry.addr)
c.cache.lru.Remove(entry)
- // Wake waiters and mark the soon-to-be-reused entry as expired. Note
- // that the state passed doesn't matter when the zero time is passed.
- entry.changeState(failed, time.Time{})
+ // Wake waiters and mark the soon-to-be-reused entry as expired.
+ entry.notifyCompletionLocked("" /* linkAddr */)
+ entry.mu.Unlock()
} else {
entry = new(linkAddrEntry)
}
@@ -180,9 +181,12 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
if linkRes != nil {
if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
+ if onResolve != nil {
+ onResolve(addr, true)
+ }
return addr, nil, nil
}
}
@@ -190,56 +194,35 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
c.cache.Lock()
defer c.cache.Unlock()
entry := c.getOrCreateEntryLocked(k)
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
switch s := entry.s; s {
- case ready, failed:
+ case ready:
if !time.Now().After(entry.expiration) {
// Not expired.
- switch s {
- case ready:
- return entry.linkAddr, nil, nil
- case failed:
- return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ if onResolve != nil {
+ onResolve(entry.linkAddr, true)
}
+ return entry.linkAddr, nil, nil
}
- entry.changeState(incomplete, time.Time{})
+ entry.changeStateLocked(incomplete, time.Time{})
fallthrough
case incomplete:
- if waker != nil {
- if entry.wakers == nil {
- entry.wakers = make(map[*sleep.Waker]struct{})
- }
- entry.wakers[waker] = struct{}{}
+ if onResolve != nil {
+ entry.onResolve = append(entry.onResolve, onResolve)
}
-
if entry.done == nil {
- // Address resolution needs to be initiated.
- if linkRes == nil {
- return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
- }
-
entry.done = make(chan struct{})
go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
-
return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
}
-// removeWaker removes a waker previously added through get().
-func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
- c.cache.Lock()
- defer c.cache.Unlock()
-
- if entry, ok := c.cache.table[k]; ok {
- entry.removeWaker(waker)
- }
-}
-
func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
@@ -257,9 +240,9 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
}
}
-// checkLinkRequest checks whether previous attempt to resolve address has succeeded
-// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
-// can stop, false if another request should be sent.
+// checkLinkRequest checks whether previous attempt to resolve address has
+// succeeded and mark the entry accordingly. Returns true if request can stop,
+// false if another request should be sent.
func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
c.cache.Lock()
defer c.cache.Unlock()
@@ -268,16 +251,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att
// Entry was evicted from the cache.
return true
}
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
switch s := entry.s; s {
- case ready, failed:
- // Entry was made ready by resolver or failed. Either way we're done.
+ case ready:
+ // Entry was made ready by resolver.
case incomplete:
if attempt+1 < c.resolutionAttempts {
// No response yet, need to send another ARP request.
return false
}
- // Max number of retries reached, mark entry as failed.
- entry.changeState(failed, now.Add(c.ageLimit))
+ // Max number of retries reached, delete entry.
+ entry.notifyCompletionLocked("" /* linkAddr */)
+ delete(c.cache.table, k)
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index d2e37f38d..6883045b5 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -21,7 +21,6 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -50,6 +49,7 @@ type testLinkAddressResolver struct {
}
func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+ // TODO(gvisor.dev/issue/5141): Use a fake clock.
time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) })
if f := r.onLinkAddressRequest; f != nil {
f()
@@ -78,16 +78,18 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
}
func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- s.AddWaker(&w, 123)
- defer s.Done()
-
+ var attemptedResolution bool
for {
- if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- return got, err
+ got, ch, err := c.get(addr, linkRes, "", nil, nil)
+ if err == tcpip.ErrWouldBlock {
+ if attemptedResolution {
+ return got, tcpip.ErrNoLinkAddress
+ }
+ attemptedResolution = true
+ <-ch
+ continue
}
- s.Fetch(true)
+ return got, err
}
}
@@ -116,16 +118,19 @@ func TestCacheOverflow(t *testing.T) {
}
}
// The earliest entries should no longer be in the cache.
+ c.cache.Lock()
+ defer c.cache.Unlock()
for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
e := testAddrs[i]
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
+ if entry, ok := c.cache.table[e.addr]; ok {
+ t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
}
}
}
func TestCacheConcurrent(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ linkRes := &testLinkAddressResolver{cache: c}
var wg sync.WaitGroup
for r := 0; r < 16; r++ {
@@ -133,7 +138,6 @@ func TestCacheConcurrent(t *testing.T) {
go func() {
for _, e := range testAddrs {
c.add(e.addr, e.linkAddr)
- c.get(e.addr, nil, "", nil, nil) // make work for gotsan
}
wg.Done()
}()
@@ -144,7 +148,7 @@ func TestCacheConcurrent(t *testing.T) {
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
e := testAddrs[len(testAddrs)-1]
- got, _, err := c.get(e.addr, nil, "", nil, nil)
+ got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
@@ -153,18 +157,22 @@ func TestCacheConcurrent(t *testing.T) {
}
e = testAddrs[0]
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ if entry, ok := c.cache.table[e.addr]; ok {
+ t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
}
}
func TestCacheAgeLimit(t *testing.T) {
c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
+ linkRes := &testLinkAddressResolver{cache: c}
+
e := testAddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err)
}
}
@@ -282,71 +290,3 @@ func TestStaticResolution(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
}
}
-
-// TestCacheWaker verifies that RemoveWaker removes a waker previously added
-// through get().
-func TestCacheWaker(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
-
- // First, sanity check that wakers are working.
- {
- linkRes := &testLinkAddressResolver{cache: c}
- s := sleep.Sleeper{}
- defer s.Done()
-
- const wakerID = 1
- w := sleep.Waker{}
- s.AddWaker(&w, wakerID)
-
- e := testAddrs[0]
-
- if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
- }
- id, ok := s.Fetch(true /* block */)
- if !ok {
- t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)")
- }
- if id != wakerID {
- t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID)
- }
-
- if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil {
- t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
- } else if got != e.linkAddr {
- t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
- }
- }
-
- // Check that RemoveWaker works.
- {
- linkRes := &testLinkAddressResolver{cache: c}
- s := sleep.Sleeper{}
- defer s.Done()
-
- const wakerID = 2 // different than the ID used in the sanity check
- w := sleep.Waker{}
- s.AddWaker(&w, wakerID)
-
- e := testAddrs[1]
- linkRes.onLinkAddressRequest = func() {
- // Remove the waker before the linkAddrCache has the opportunity to send
- // a notification.
- c.removeWaker(e.addr, &w)
- }
-
- if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
- }
-
- if got, err := getBlocking(c, e.addr, linkRes); err != nil {
- t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
- } else if got != e.linkAddr {
- t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
- }
-
- if id, ok := s.Fetch(false /* block */); ok {
- t.Fatalf("unexpected notification from waker with id %d", id)
- }
- }
-}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 73a01c2dd..61636cae5 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -352,7 +353,7 @@ func TestDADDisabled(t *testing.T) {
}
// We should not have sent any NDP NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 {
t.Fatalf("got NeighborSolicit = %d, want = 0", got)
}
}
@@ -465,14 +466,18 @@ func TestDADResolve(t *testing.T) {
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
{
r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -510,7 +515,9 @@ func TestDADResolve(t *testing.T) {
} else if r.LocalAddress != addr1 {
t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -518,7 +525,7 @@ func TestDADResolve(t *testing.T) {
}
// Should not have sent any more NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
}
@@ -563,18 +570,18 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(tgt)
snmc := header.SolicitedNodeAddr(tgt)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: header.IPv6Any,
- DstAddr: snmc,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
}
@@ -605,7 +612,7 @@ func TestDADFail(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(tgt)
@@ -616,11 +623,11 @@ func TestDADFail(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: tgt,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: tgt,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
},
@@ -666,7 +673,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate an address conflict.
test.rxPkt(e, addr1)
- stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
+ stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived)
if got := stat.Value(); got != 1 {
t.Fatalf("got stat = %d, want = 1", got)
}
@@ -803,7 +810,7 @@ func TestDADStop(t *testing.T) {
}
// Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 {
t.Errorf("got NeighborSolicit = %d, want <= 1", got)
}
})
@@ -982,7 +989,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(0)
- raPayload := pkt.NDPPayload()
+ raPayload := pkt.MessageBody()
ra := header.NDPRouterAdvert(raPayload)
// Populate the Router Lifetime.
binary.BigEndian.PutUint16(raPayload[2:], rl)
@@ -1004,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
iph.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: ip,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: ip,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
return stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -2162,8 +2169,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
NDPConfigs: ipv6.NDPConfigurations{
AutoGenTempGlobalAddresses: true,
},
- NDPDisp: &ndpDisp,
- AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: true,
})},
})
@@ -2843,9 +2850,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(addr); err != nil {
t.Fatalf("ep.Connect(%+v): %s", addr, err)
}
@@ -2879,9 +2884,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Bind(addr); err != nil {
t.Fatalf("ep.Bind(%+v): %s", addr, err)
}
@@ -3250,9 +3253,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
@@ -4044,9 +4045,9 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
ndpConfigs.AutoGenAddressConflictRetries = maxRetries
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(_ tcpip.NICID, nicName string) string {
return nicName
@@ -4179,9 +4180,9 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: addrType.ndpConfigs,
- NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
})},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -4708,7 +4709,7 @@ func TestCleanupNDPState(t *testing.T) {
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: true,
+ AutoGenLinkLocal: true,
NDPConfigs: ipv6.NDPConfigurations{
HandleRAs: true,
DiscoverDefaultRouters: true,
@@ -5174,113 +5175,99 @@ func TestRouterSolicitation(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
- e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
- headerLength: test.linkHeaderLen,
+ clock.Advance(timeout)
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected router solicitation packet")
}
- e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
- // Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
+ // Make sure the right remote link address is used.
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
- )
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.expectedSrcAddr),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
+ )
- if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
- }
- }
- waitForNothing := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet")
- }
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
- })},
- })
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
}
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
- if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
- }
+ clock.Advance(timeout)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpectedly got a packet = %#v", p)
}
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // Make sure each RS is sent at the right time.
- remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
- remaining--
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
}
+ }
- for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
- waitForPkt(defaultAsyncPositiveEventTimeout)
- } else {
- waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
- }
- }
+ // Make sure each RS is sent at the right time.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay)
+ remaining--
+ }
- // Make sure no more RS.
- if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
+ for ; remaining > 0; remaining-- {
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
+ waitForPkt(time.Nanosecond)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
+ waitForPkt(test.effectiveRtrSolicitInt)
}
+ }
- // Make sure the counter got properly
- // incremented.
- if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
- t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
- }
- })
- }
- })
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay)
+ }
+
+ if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
+ }
}
func TestStopStartSolicitingRouters(t *testing.T) {
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 177bf5516..c15f10e76 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -17,16 +17,22 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
const neighborCacheSize = 512 // max entries per interface
+// NeighborStats holds metrics for the neighbor table.
+type NeighborStats struct {
+ // FailedEntryLookups counts the number of lookups performed on an entry in
+ // Failed state.
+ FailedEntryLookups *tcpip.StatCounter
+}
+
// neighborCache maps IP addresses to link addresses. It uses the Least
// Recently Used (LRU) eviction strategy to implement a bounded cache for
-// dynmically acquired entries. It contains the state machine and configuration
+// dynamically acquired entries. It contains the state machine and configuration
// for running Neighbor Unreachability Detection (NUD).
//
// There are two types of entries in the neighbor cache:
@@ -92,9 +98,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
n.dynamic.lru.Remove(e)
n.dynamic.count--
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Unknown)
- e.notifyWakersLocked()
+ e.removeLocked()
e.mu.Unlock()
}
n.cache[remoteAddr] = entry
@@ -103,21 +107,27 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
return entry
}
-// entry looks up the neighbor cache for translating address to link address
-// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there
-// is a LinkAddressResolver registered with the network protocol, the cache
-// attempts to resolve the address and returns ErrWouldBlock. If a Waker is
-// provided, it will be notified when address resolution is complete (success
-// or not).
+// entry looks up neighbor information matching the remote address, and returns
+// it if readily available.
+//
+// Returns ErrWouldBlock if the link address is not readily available, along
+// with a notification channel for the caller to block on. Triggers address
+// resolution asynchronously.
+//
+// If onResolve is provided, it will be called either immediately, if resolution
+// is not required, or when address resolution is complete, with the resolved
+// link address and whether resolution succeeded. After any callbacks have been
+// called, the returned notification channel is closed.
+//
+// NB: if a callback is provided, it should not call into the neighbor cache.
//
// If specified, the local address must be an address local to the interface the
// neighbor cache belongs to. The local address is the source address of a
// packet prompting NUD/link address resolution.
//
-// If address resolution is required, ErrNoLinkAddress and a notification
-// channel is returned for the top level caller to block. Channel is closed
-// once address resolution is complete (success or not).
-func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry.
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/5149): Handle static resolution in route.Resolve.
if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
e := NeighborEntry{
Addr: remoteAddr,
@@ -125,6 +135,9 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
State: Static,
UpdatedAtNanos: 0,
}
+ if onResolve != nil {
+ onResolve(linkAddr, true)
+ }
return e, nil, nil
}
@@ -142,47 +155,36 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
// of packets to a neighbor. While reasserting a neighbor's reachability,
// a node continues sending packets to that neighbor using the cached
// link-layer address."
+ if onResolve != nil {
+ onResolve(entry.neigh.LinkAddr, true)
+ }
return entry.neigh, nil, nil
- case Unknown, Incomplete:
- entry.addWakerLocked(w)
-
+ case Unknown, Incomplete, Failed:
+ if onResolve != nil {
+ entry.onResolve = append(entry.onResolve, onResolve)
+ }
if entry.done == nil {
// Address resolution needs to be initiated.
- if linkRes == nil {
- return entry.neigh, nil, tcpip.ErrNoLinkAddress
- }
entry.done = make(chan struct{})
}
-
entry.handlePacketQueuedLocked(localAddr)
return entry.neigh, entry.done, tcpip.ErrWouldBlock
- case Failed:
- return entry.neigh, nil, tcpip.ErrNoLinkAddress
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
}
-// removeWaker removes a waker that has been added when link resolution for
-// addr was requested.
-func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
- n.mu.Lock()
- if entry, ok := n.cache[addr]; ok {
- delete(entry.wakers, waker)
- }
- n.mu.Unlock()
-}
-
// entries returns all entries in the neighbor cache.
func (n *neighborCache) entries() []NeighborEntry {
- entries := make([]NeighborEntry, 0, len(n.cache))
n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ entries := make([]NeighborEntry, 0, len(n.cache))
for _, entry := range n.cache {
entry.mu.RLock()
entries = append(entries, entry.neigh)
entry.mu.RUnlock()
}
- n.mu.RUnlock()
return entries
}
@@ -214,32 +216,13 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
return
}
- // Notify that resolution has been interrupted, just in case the entry was
- // in the Incomplete or Probe state.
- entry.dispatchRemoveEventLocked()
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
+ entry.removeLocked()
entry.mu.Unlock()
}
n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
}
-// removeEntryLocked removes the specified entry from the neighbor cache.
-func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
- if entry.neigh.State != Static {
- n.dynamic.lru.Remove(entry)
- n.dynamic.count--
- }
- if entry.neigh.State != Failed {
- entry.dispatchRemoveEventLocked()
- }
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
-
- delete(n.cache, entry.neigh.Addr)
-}
-
// removeEntry removes a dynamic or static entry by address from the neighbor
// cache. Returns true if the entry was found and deleted.
func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
@@ -254,7 +237,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
entry.mu.Lock()
defer entry.mu.Unlock()
- n.removeEntryLocked(entry)
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ }
+
+ entry.removeLocked()
+ delete(n.cache, entry.neigh.Addr)
return true
}
@@ -265,9 +254,7 @@ func (n *neighborCache) clear() {
for _, entry := range n.cache {
entry.mu.Lock()
- entry.dispatchRemoveEventLocked()
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
+ entry.removeLocked()
entry.mu.Unlock()
}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index ed33418f3..a2ed6ae2a 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -28,7 +28,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
)
@@ -80,17 +79,20 @@ func entryDiffOptsWithSort() []cmp.Option {
func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
config.resetInvalidFields()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
- return &neighborCache{
+ neigh := &neighborCache{
nic: &NIC{
stack: &Stack{
clock: clock,
nudDisp: nudDisp,
},
- id: 1,
+ id: 1,
+ stats: makeNICStats(),
},
state: NewNUDState(config, rng),
cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
}
+ neigh.nic.neigh = neigh
+ return neigh
}
// testEntryStore contains a set of IP to NeighborEntry mappings.
@@ -187,15 +189,18 @@ type testNeighborResolver struct {
entries *testEntryStore
delay time.Duration
onLinkAddressRequest func()
+ dropReplies bool
}
var _ LinkAddressResolver = (*testNeighborResolver)(nil)
func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
- // Delay handling the request to emulate network latency.
- r.clock.AfterFunc(r.delay, func() {
- r.fakeRequest(targetAddr)
- })
+ if !r.dropReplies {
+ // Delay handling the request to emulate network latency.
+ r.clock.AfterFunc(r.delay, func() {
+ r.fakeRequest(targetAddr)
+ })
+ }
// Execute post address resolution action, if available.
if f := r.onLinkAddressRequest; f != nil {
@@ -288,10 +293,10 @@ func TestNeighborCacheEntry(t *testing.T) {
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -324,7 +329,7 @@ func TestNeighborCacheEntry(t *testing.T) {
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
// No more events should have been dispatched.
@@ -351,11 +356,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -410,7 +415,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
}
@@ -458,7 +463,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
@@ -510,7 +515,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
}
// Expect to find only the most recent entries. The order of entries reported
- // by entries() is undeterministic, so entries have to be sorted before
+ // by entries() is nondeterministic, so entries have to be sorted before
// comparison.
wantUnsortedEntries := opts.wantStaticEntries
for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ {
@@ -572,10 +577,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
@@ -647,7 +652,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
// Add a static entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
@@ -691,7 +696,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
// Add a static entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
@@ -753,7 +758,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
// Add a static entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
@@ -823,10 +828,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -904,150 +909,6 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
}
}
-func TestNeighborCacheNotifiesWaker(t *testing.T) {
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
-
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- const wakerID = 1
- s.AddWaker(&w, wakerID)
-
- entry, ok := store.entry(0)
- if !ok {
- t.Fatalf("store.entry(0) not found")
- }
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
- }
- if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
- }
- clock.Advance(typicalLatency)
-
- select {
- case <-doneCh:
- default:
- t.Fatal("expected notification from done channel")
- }
-
- id, ok := s.Fetch(false /* block */)
- if !ok {
- t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr)
- }
- if id != wakerID {
- t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID)
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
- }
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
- }
-}
-
-func TestNeighborCacheRemoveWaker(t *testing.T) {
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
-
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- const wakerID = 1
- s.AddWaker(&w, wakerID)
-
- entry, ok := store.entry(0)
- if !ok {
- t.Fatalf("store.entry(0) not found")
- }
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
- }
- if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
- }
-
- // Remove the waker before the neighbor cache has the opportunity to send a
- // notification.
- neigh.removeWaker(entry.Addr, &w)
- clock.Advance(typicalLatency)
-
- select {
- case <-doneCh:
- default:
- t.Fatal("expected notification from done channel")
- }
-
- if id, ok := s.Fetch(false /* block */); ok {
- t.Errorf("unexpected notification from waker with id %d", id)
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
- }
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
- }
-}
-
func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
config := DefaultNUDConfigurations()
// Stay in Reachable so the cache can overflow
@@ -1059,12 +920,12 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err)
+ t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1072,7 +933,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
State: Static,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
wantEvents := []testEntryEventInfo{
@@ -1126,10 +987,10 @@ func TestNeighborCacheClear(t *testing.T) {
// Add a dynamic entry.
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -1184,7 +1045,7 @@ func TestNeighborCacheClear(t *testing.T) {
}
}
- // Clear shoud remove both dynamic and static entries.
+ // Clear should remove both dynamic and static entries.
neigh.clear()
// Remove events dispatched from clear() have no deterministic order so they
@@ -1231,10 +1092,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -1315,7 +1176,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
frequentlyUsedEntry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
// The following logic is very similar to overflowCache, but
@@ -1327,15 +1188,22 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
- case <-doneCh:
+ case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
wantEvents := []testEntryEventInfo{
{
@@ -1370,7 +1238,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
// Periodically refresh the frequently used entry
if i%(neighborCacheSize/2) == 0 {
if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err)
}
}
@@ -1378,15 +1246,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
- case <-doneCh:
+ case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
// An entry should have been removed, as per the LRU eviction strategy
@@ -1432,7 +1308,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
}
// Expect to find only the frequently used entry and the most recent entries.
- // The order of entries reported by entries() is undeterministic, so entries
+ // The order of entries reported by entries() is nondeterministic, so entries
// have to be sorted before comparison.
wantUnsortedEntries := []NeighborEntry{
{
@@ -1491,12 +1367,12 @@ func TestNeighborCacheConcurrent(t *testing.T) {
go func(entry NeighborEntry) {
defer wg.Done()
if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
}
}(entry)
}
- // Wait for all gorountines to send a request
+ // Wait for all goroutines to send a request
wg.Wait()
// Process all the requests for a single entry concurrently
@@ -1506,7 +1382,7 @@ func TestNeighborCacheConcurrent(t *testing.T) {
// All goroutines add in the same order and add more values than can fit in
// the cache. Our eviction strategy requires that the last entries are
// present, up to the size of the neighbor cache, and the rest are missing.
- // The order of entries reported by entries() is undeterministic, so entries
+ // The order of entries reported by entries() is nondeterministic, so entries
// have to be sorted before comparison.
var wantUnsortedEntries []NeighborEntry
for i := store.size() - neighborCacheSize; i < store.size(); i++ {
@@ -1544,27 +1420,32 @@ func TestNeighborCacheReplace(t *testing.T) {
// Add an entry
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
- case <-doneCh:
+ case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
// Verify the entry exists
{
- e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
- }
- if doneCh != nil {
- t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err)
}
if t.Failed() {
t.FailNow()
@@ -1575,7 +1456,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
@@ -1584,7 +1465,7 @@ func TestNeighborCacheReplace(t *testing.T) {
{
entry, ok := store.entry(1)
if !ok {
- t.Fatalf("store.entry(1) not found")
+ t.Fatal("store.entry(1) not found")
}
updatedLinkAddr = entry.LinkAddr
}
@@ -1601,7 +1482,7 @@ func TestNeighborCacheReplace(t *testing.T) {
{
e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1609,7 +1490,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Delay,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
@@ -1619,7 +1500,7 @@ func TestNeighborCacheReplace(t *testing.T) {
e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
clock.Advance(typicalLatency)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1627,7 +1508,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
}
@@ -1651,18 +1532,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
},
}
- // First, sanity check that resolution is working
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+
+ // First, sanity check that resolution is working
+ {
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ clock.Advance(typicalLatency)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
}
- clock.Advance(typicalLatency)
+
got, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1670,20 +1568,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
- // Verify that address resolution for an unknown address returns ErrNoLinkAddress
+ // Verify address resolution fails for an unknown address.
before := atomic.LoadUint32(&requestCount)
entry.Addr += "2"
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
- }
- waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
- clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
+ {
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if ok {
+ t.Error("expected unsuccessful address resolution")
+ }
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
+ clock.Advance(waitFor)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
}
maxAttempts := neigh.config().MaxUnicastProbes
@@ -1711,15 +1624,129 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if ok {
+ t.Error("expected unsuccessful address resolution")
+ }
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
+
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
+}
+
+// TestNeighborCacheRetryResolution simulates retrying communication after
+// failing to perform address resolution.
+func TestNeighborCacheRetryResolution(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ // Simulate a faulty link.
+ dropReplies: true,
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatal("store.entry(0) not found")
+ }
+
+ // Perform address resolution with a faulty link, which will fail.
+ {
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if ok {
+ t.Error("expected unsuccessful address resolution")
+ }
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
+ clock.Advance(waitFor)
+
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
+ }
+
+ // Verify the entry is in Failed state.
+ wantEntries := []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Failed,
+ },
+ }
+ if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" {
+ t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Retry address resolution with a working link.
+ linkRes.dropReplies = false
+ {
+ incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ if incompleteEntry.State != Incomplete {
+ t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete)
+ }
+ clock.Advance(typicalLatency)
+
+ select {
+ case <-ch:
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ if err != nil {
+ t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err)
+ }
+ if reachableEntry.Addr != entry.Addr {
+ t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr)
+ }
+ if reachableEntry.LinkAddr != entry.LinkAddr {
+ t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr)
+ }
+ if reachableEntry.State != Reachable {
+ t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String())
+ }
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
}
}
@@ -1739,7 +1766,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) {
got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", testEntryBroadcastAddr, err)
}
want := NeighborEntry{
Addr: testEntryBroadcastAddr,
@@ -1747,7 +1774,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) {
State: Static,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff)
}
}
@@ -1772,12 +1799,23 @@ func BenchmarkCacheClear(b *testing.B) {
if !ok {
b.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ b.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
- if doneCh != nil {
- <-doneCh
+
+ select {
+ case <-ch:
+ default:
+ b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
}
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 493e48031..75afb3001 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -19,7 +19,6 @@ import (
"sync"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -67,8 +66,7 @@ const (
// Static describes entries that have been explicitly added by the user. They
// do not expire and are not deleted until explicitly removed.
Static
- // Failed means traffic should not be sent to this neighbor since attempts of
- // reachability have returned inconclusive.
+ // Failed means recent attempts of reachability have returned inconclusive.
Failed
)
@@ -93,16 +91,13 @@ type neighborEntry struct {
neigh NeighborEntry
- // wakers is a set of waiters for address resolution result. Anytime state
- // transitions out of incomplete these waiters are notified. It is nil iff
- // address resolution is ongoing and no clients are waiting for the result.
- wakers map[*sleep.Waker]struct{}
-
- // done is used to allow callers to wait on address resolution. It is nil
- // iff nudState is not Reachable and address resolution is not yet in
- // progress.
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
done chan struct{}
+ // onResolve is called with the result of address resolution.
+ onResolve []func(tcpip.LinkAddress, bool)
+
isRouter bool
job *tcpip.Job
}
@@ -143,25 +138,15 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd
}
}
-// addWaker adds w to the list of wakers waiting for address resolution.
-// Assumes the entry has already been appropriately locked.
-func (e *neighborEntry) addWakerLocked(w *sleep.Waker) {
- if w == nil {
- return
- }
- if e.wakers == nil {
- e.wakers = make(map[*sleep.Waker]struct{})
- }
- e.wakers[w] = struct{}{}
-}
-
-// notifyWakersLocked notifies those waiting for address resolution, whether it
-// succeeded or failed. Assumes the entry has already been appropriately locked.
-func (e *neighborEntry) notifyWakersLocked() {
- for w := range e.wakers {
- w.Assert()
+// notifyCompletionLocked notifies those waiting for address resolution, with
+// the link address if resolution completed successfully.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
+ for _, callback := range e.onResolve {
+ callback(e.neigh.LinkAddr, succeeded)
}
- e.wakers = nil
+ e.onResolve = nil
if ch := e.done; ch != nil {
close(ch)
e.done = nil
@@ -170,6 +155,8 @@ func (e *neighborEntry) notifyWakersLocked() {
// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
// been added.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchAddEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborAdded(e.nic.id, e.neigh)
@@ -178,6 +165,8 @@ func (e *neighborEntry) dispatchAddEventLocked() {
// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
// has changed state or link-layer address.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchChangeEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborChanged(e.nic.id, e.neigh)
@@ -186,23 +175,41 @@ func (e *neighborEntry) dispatchChangeEventLocked() {
// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry
// has been removed.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchRemoveEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborRemoved(e.nic.id, e.neigh)
}
}
+// cancelJobLocked cancels the currently scheduled action, if there is one.
+// Entries in Unknown, Stale, or Static state do not have a scheduled action.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) cancelJobLocked() {
+ if job := e.job; job != nil {
+ job.Cancel()
+ }
+}
+
+// removeLocked prepares the entry for removal.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) removeLocked() {
+ e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
+ e.dispatchRemoveEventLocked()
+ e.cancelJobLocked()
+ e.notifyCompletionLocked(false /* succeeded */)
+}
+
// setStateLocked transitions the entry to the specified state immediately.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
//
-// e.mu MUST be locked.
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) setStateLocked(next NeighborState) {
- // Cancel the previously scheduled action, if there is one. Entries in
- // Unknown, Stale, or Static state do not have scheduled actions.
- if timer := e.job; timer != nil {
- timer.Cancel()
- }
+ e.cancelJobLocked()
prev := e.neigh.State
e.neigh.State = next
@@ -257,11 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
e.job.Schedule(immediateDuration)
case Failed:
- e.notifyWakersLocked()
- e.job = e.nic.stack.newJob(&e.mu, func() {
- e.nic.neigh.removeEntryLocked(e)
- })
- e.job.Schedule(config.UnreachableTime)
+ e.notifyCompletionLocked(false /* succeeded */)
case Unknown, Stale, Static:
// Do nothing
@@ -275,8 +278,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
// being queued for outgoing transmission.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
switch e.neigh.State {
+ case Failed:
+ e.nic.stats.Neighbor.FailedEntryLookups.Increment()
+
+ fallthrough
case Unknown:
e.neigh.State = Incomplete
e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
@@ -309,7 +318,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// implementation may find it convenient in some cases to return errors
// to the sender by taking the offending packet, generating an ICMP
// error message, and then delivering it (locally) through the generic
- // error-handling routines.' - RFC 4861 section 2.1
+ // error-handling routines." - RFC 4861 section 2.1
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
@@ -347,9 +356,8 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
e.setStateLocked(Delay)
e.dispatchChangeEventLocked()
- case Incomplete, Reachable, Delay, Probe, Static, Failed:
+ case Incomplete, Reachable, Delay, Probe, Static:
// Do nothing
-
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
@@ -359,18 +367,30 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// Neighbor Solicitation for ARP or NDP, respectively).
//
// Follows the logic defined in RFC 4861 section 7.2.3.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// Probes MUST be silently discarded if the target address is tentative, does
// not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These
// checks MUST be done by the NetworkEndpoint.
switch e.neigh.State {
- case Unknown, Incomplete, Failed:
+ case Unknown, Failed:
e.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
- e.notifyWakersLocked()
e.dispatchAddEventLocked()
+ case Incomplete:
+ // "If an entry already exists, and the cached link-layer address
+ // differs from the one in the received Source Link-Layer option, the
+ // cached address should be replaced by the received address, and the
+ // entry's reachability state MUST be set to STALE."
+ // - RFC 4861 section 7.2.3
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.setStateLocked(Stale)
+ e.notifyCompletionLocked(true /* succeeded */)
+ e.dispatchChangeEventLocked()
+
case Reachable, Delay, Probe:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
@@ -403,6 +423,8 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// not be possible. SEND uses RSA key pairs to produce Cryptographically
// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the
// claimed source of an NDP message is the owner of the claimed address.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
switch e.neigh.State {
case Incomplete:
@@ -421,7 +443,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
}
e.dispatchChangeEventLocked()
e.isRouter = flags.IsRouter
- e.notifyWakersLocked()
+ e.notifyCompletionLocked(true /* succeeded */)
// "Note that the Override flag is ignored if the entry is in the
// INCOMPLETE state." - RFC 4861 section 7.2.5
@@ -456,7 +478,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
wasReachable := e.neigh.State == Reachable
// Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
- e.notifyWakersLocked()
+ e.notifyCompletionLocked(true /* succeeded */)
if !wasReachable {
e.dispatchChangeEventLocked()
}
@@ -494,6 +516,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol
// (e.g. TCP acknowledgements) reachability confirmation.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
switch e.neigh.State {
case Reachable, Stale, Delay, Probe:
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index c2b763325..ec34ffa5a 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -25,7 +25,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -73,35 +72,36 @@ func eventDiffOptsWithSort() []cmp.Option {
// The following unit tests exercise every state transition and verify its
// behavior with RFC 4681.
//
-// | From | To | Cause | Action | Event |
-// | ========== | ========== | ========================================== | =============== | ======= |
-// | Unknown | Unknown | Confirmation w/ unknown address | | Added |
-// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added |
-// | Unknown | Stale | Probe w/ unknown address | | Added |
-// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed |
-// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed |
-// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed |
-// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed |
-// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | |
-// | Reachable | Stale | Reachable timer expired | | Changed |
-// | Reachable | Stale | Probe or confirmation w/ different address | | Changed |
-// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
-// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
-// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
-// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
-// | Stale | Delay | Packet sent | | Changed |
-// | Delay | Reachable | Upper-layer confirmation | | Changed |
-// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
-// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
-// | Delay | Stale | Probe or confirmation w/ different address | | Changed |
-// | Delay | Probe | Delay timer expired | Send probe | Changed |
-// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
-// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed |
-// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
-// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
-// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
-// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
-// | Failed | | Unreachability timer expired | Delete entry | |
+// | From | To | Cause | Update | Action | Event |
+// | ========== | ========== | ========================================== | ======== | ===========| ======= |
+// | Unknown | Unknown | Confirmation w/ unknown address | | | Added |
+// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added |
+// | Unknown | Stale | Probe | | | Added |
+// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed |
+// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed |
+// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed |
+// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed |
+// | Incomplete | Failed | Max probes sent without reply | | Notify | Removed |
+// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | |
+// | Reachable | Stale | Reachable timer expired | | | Changed |
+// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Stale | Stale | Override confirmation | LinkAddr | | Changed |
+// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed |
+// | Stale | Delay | Packet sent | | | Changed |
+// | Delay | Reachable | Upper-layer confirmation | | | Changed |
+// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Delay | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Delay | Probe | Delay timer expired | | Send probe | Changed |
+// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed |
+// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Probe | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Probe | Probe | Retransmit timer expired | | | Changed |
+// | Probe | Failed | Max probes sent without reply | | Notify | Removed |
+// | Failed | Incomplete | Packet queued | | Send probe | Added |
type testEntryEventType uint8
@@ -228,6 +228,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
clock: clock,
nudDisp: &disp,
},
+ stats: makeNICStats(),
}
nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil),
@@ -256,8 +257,8 @@ func TestEntryInitiallyUnknown(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- if got, want := e.neigh.State, Unknown; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Unknown {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown)
}
e.mu.Unlock()
@@ -289,8 +290,8 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
Override: false,
IsRouter: false,
})
- if got, want := e.neigh.State, Unknown; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Unknown {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown)
}
e.mu.Unlock()
@@ -318,8 +319,8 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -365,8 +366,8 @@ func TestEntryUnknownToStale(t *testing.T) {
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -404,8 +405,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
updatedAtNanos := e.neigh.UpdatedAtNanos
e.mu.Unlock()
@@ -558,21 +559,15 @@ func TestEntryIncompleteToReachable(t *testing.T) {
nudDisp.mu.Unlock()
}
-// TestEntryAddsAndClearsWakers verifies that wakers are added when
-// addWakerLocked is called and cleared when address resolution finishes. In
-// this case, address resolution will finish when transitioning from Incomplete
-// to Reachable.
-func TestEntryAddsAndClearsWakers(t *testing.T) {
+func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- s.AddWaker(&w, 123)
- defer s.Done()
-
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ }
e.mu.Unlock()
runImmediatelyScheduledJobs(clock)
@@ -591,26 +586,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
}
e.mu.Lock()
- if got := e.wakers; got != nil {
- t.Errorf("got e.wakers = %v, want = nil", got)
- }
- e.addWakerLocked(&w)
- if got, want := w.IsAsserted(), false; got != want {
- t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
- }
- if e.wakers == nil {
- t.Error("expected e.wakers to be non-nil")
- }
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
- IsRouter: false,
+ IsRouter: true,
})
- if e.wakers != nil {
- t.Errorf("got e.wakers = %v, want = nil", e.wakers)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
- if got, want := w.IsAsserted(), true; got != want {
- t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
+ if !e.isRouter {
+ t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
}
e.mu.Unlock()
@@ -641,7 +626,7 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
+func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
@@ -661,22 +646,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
},
}
linkRes.mu.Lock()
- if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" {
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- linkRes.mu.Unlock()
e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
+ Solicited: false,
Override: false,
- IsRouter: true,
+ IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
- }
- if !e.isRouter {
- t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -696,7 +679,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
Entry: NeighborEntry{
Addr: entryTestAddr1,
LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ State: Stale,
},
},
}
@@ -707,7 +690,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryIncompleteToStale(t *testing.T) {
+func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
@@ -734,11 +717,7 @@ func TestEntryIncompleteToStale(t *testing.T) {
}
e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
+ e.handleProbeLocked(entryTestLinkAddr1)
if e.neigh.State != Stale {
t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
@@ -778,8 +757,8 @@ func TestEntryIncompleteToFailed(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -839,8 +818,8 @@ func TestEntryIncompleteToFailed(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Failed; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
}
e.mu.Unlock()
}
@@ -883,8 +862,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
Override: false,
IsRouter: true,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
if got, want := e.isRouter, true; got != want {
t.Errorf("got e.isRouter = %t, want = %t", got, want)
@@ -930,8 +909,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
e.mu.Unlock()
}
@@ -1081,8 +1060,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
}
@@ -2379,8 +2358,8 @@ func TestEntryDelayToProbe(t *testing.T) {
IsRouter: false,
})
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
}
e.mu.Unlock()
@@ -2445,8 +2424,8 @@ func TestEntryDelayToProbe(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.mu.Unlock()
}
@@ -2503,12 +2482,12 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -2618,16 +2597,16 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -2738,16 +2717,16 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
@@ -2834,16 +2813,16 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: true,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
@@ -2962,16 +2941,16 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: true,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
@@ -3099,16 +3078,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
IsRouter: false,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
e.mu.Unlock()
@@ -3433,72 +3412,61 @@ func TestEntryProbeToFailed(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryFailedGetsDeleted(t *testing.T) {
+func TestEntryFailedToIncomplete(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
- c.MaxUnicastProbes = 3
e, nudDisp, linkRes, clock := entryTestSetup(c)
- // Verify the cache contains the entry.
- if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
- t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
- }
-
+ // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in
+ // their expected state.
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ }
e.mu.Unlock()
- runImmediatelyScheduledJobs(clock)
- {
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.probes = nil
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
+ waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes)
+ clock.Advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The Incomplete-to-Incomplete state transition is tested here by
+ // verifying that 3 reachability probes were sent.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
+ }
e.mu.Unlock()
- waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
- clock.Advance(waitFor)
- {
- wantProbes := []entryTestProbeInfo{
- // The next three probe are sent in Probe.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
+ e.mu.Lock()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
+ e.mu.Unlock()
wantEvents := []testEntryEventInfo{
{
@@ -3511,39 +3479,21 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
},
},
{
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
- },
- },
- {
- EventType: entryTestChanged,
+ EventType: entryTestRemoved,
NICID: entryTestNICID,
Entry: NeighborEntry{
Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
},
},
{
- EventType: entryTestRemoved,
+ EventType: entryTestAdded,
NICID: entryTestNICID,
Entry: NeighborEntry{
Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
},
},
}
@@ -3552,9 +3502,4 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- // Verify the cache no longer contains the entry.
- if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok {
- t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1)
- }
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 60c81a3aa..4a34805b5 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -20,7 +20,6 @@ import (
"reflect"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -54,18 +53,20 @@ type NIC struct {
sync.RWMutex
spoofing bool
promiscuous bool
- // packetEPs is protected by mu, but the contained PacketEndpoint
- // values are not.
- packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
+ // packetEPs is protected by mu, but the contained packetEndpointList are
+ // not.
+ packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList
}
}
-// NICStats includes transmitted and received stats.
+// NICStats hold statistics for a NIC.
type NICStats struct {
Tx DirectionStats
Rx DirectionStats
DisabledRx DirectionStats
+
+ Neighbor NeighborStats
}
func makeNICStats() NICStats {
@@ -80,6 +81,39 @@ type DirectionStats struct {
Bytes *tcpip.StatCounter
}
+type packetEndpointList struct {
+ mu sync.RWMutex
+
+ // eps is protected by mu, but the contained PacketEndpoint values are not.
+ eps []PacketEndpoint
+}
+
+func (p *packetEndpointList) add(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.eps = append(p.eps, ep)
+}
+
+func (p *packetEndpointList) remove(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ for i, epOther := range p.eps {
+ if epOther == ep {
+ p.eps = append(p.eps[:i], p.eps[i+1:]...)
+ break
+ }
+ }
+}
+
+// forEach calls fn with each endpoints in p while holding the read lock on p.
+func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ for _, ep := range p.eps {
+ fn(ep)
+ }
+}
+
// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
@@ -100,7 +134,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
- nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
+ nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
// Check for Neighbor Unreachability Detection support.
var nud NUDHandler
@@ -123,11 +157,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// Register supported packet and network endpoint protocols.
for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = []PacketEndpoint{}
+ nic.mu.packetEPs[netProto] = new(packetEndpointList)
}
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = nil
+ nic.mu.packetEPs[netNum] = new(packetEndpointList)
nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
@@ -170,7 +204,7 @@ func (n *NIC) disable() {
//
// n MUST be locked.
func (n *NIC) disableLocked() {
- if !n.setEnabled(false) {
+ if !n.Enabled() {
return
}
@@ -182,6 +216,10 @@ func (n *NIC) disableLocked() {
for _, ep := range n.networkEndpoints {
ep.Disable()
}
+
+ if !n.setEnabled(false) {
+ panic("should have only done work to disable the NIC if it was enabled")
+ }
}
// enable enables n.
@@ -232,7 +270,8 @@ func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Unlock()
}
-func (n *NIC) isPromiscuousMode() bool {
+// Promiscuous implements NetworkInterface.
+func (n *NIC) Promiscuous() bool {
n.mu.RLock()
rv := n.mu.promiscuous
n.mu.RUnlock()
@@ -255,16 +294,18 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// the same unresolved IP address, and transmit the saved
// packet when the address has been resolved.
//
- // RFC 4861 section 5.2 (for IPv6):
- // Once the IP address of the next-hop node is known, the sender
- // examines the Neighbor Cache for link-layer information about that
- // neighbor. If no entry exists, the sender creates one, sets its state
- // to INCOMPLETE, initiates Address Resolution, and then queues the data
- // packet pending completion of address resolution.
+ // RFC 4861 section 7.2.2 (for IPv6):
+ // While waiting for address resolution to complete, the sender MUST, for
+ // each neighbor, retain a small queue of packets waiting for address
+ // resolution to complete. The queue MUST hold at least one packet, and MAY
+ // contain more. However, the number of queued packets per neighbor SHOULD
+ // be limited to some small value. When a queue overflows, the new arrival
+ // SHOULD replace the oldest entry. Once address resolution completes, the
+ // node transmits any queued packets.
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
- r := r.Clone()
- n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ r.Acquire()
+ n.stack.linkResQueue.enqueue(ch, r, protocol, pkt)
return nil
}
return err
@@ -276,9 +317,11 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// WritePacketToRemote implements NetworkInterface.
func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
r := Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
+ routeInfo: routeInfo{
+ NetProto: protocol,
+ },
}
+ r.ResolveWith(remoteLinkAddr)
return n.writePacket(&r, gso, protocol, pkt)
}
@@ -320,16 +363,21 @@ func (n *NIC) setSpoofing(enable bool) {
// primaryAddress returns an address that can be used to communicate with
// remoteAddr.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
- n.mu.RLock()
- spoofing := n.mu.spoofing
- n.mu.RUnlock()
-
ep, ok := n.networkEndpoints[protocol]
if !ok {
return nil
}
- return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return nil
+ }
+
+ n.mu.RLock()
+ spoofing := n.mu.spoofing
+ n.mu.RUnlock()
+
+ return addressableEndpoint.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
}
type getAddressBehaviour int
@@ -388,11 +436,17 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre
// getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean
// is passed to indicate whether or not we should generate temporary endpoints.
func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
- if ep, ok := n.networkEndpoints[protocol]; ok {
- return ep.AcquireAssignedAddress(address, createTemp, peb)
+ ep, ok := n.networkEndpoints[protocol]
+ if !ok {
+ return nil
}
- return nil
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return nil
+ }
+
+ return addressableEndpoint.AcquireAssignedAddress(address, createTemp, peb)
}
// addAddress adds a new address to n, so that it starts accepting packets
@@ -403,7 +457,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return tcpip.ErrUnknownProtocol
}
- addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+
+ addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
if err == nil {
// We have no need for the address endpoint.
addressEndpoint.DecRef()
@@ -416,7 +475,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for p, ep := range n.networkEndpoints {
- for _, a := range ep.PermanentAddresses() {
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ continue
+ }
+
+ for _, a := range addressableEndpoint.PermanentAddresses() {
addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
}
@@ -427,7 +491,12 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for p, ep := range n.networkEndpoints {
- for _, a := range ep.PrimaryAddresses() {
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ continue
+ }
+
+ for _, a := range addressableEndpoint.PrimaryAddresses() {
addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
}
@@ -445,13 +514,23 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
return tcpip.AddressWithPrefix{}
}
- return ep.MainAddress()
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ return addressableEndpoint.MainAddress()
}
// removeAddress removes an address from n.
func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error {
for _, ep := range n.networkEndpoints {
- if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ continue
+ }
+
+ if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
continue
} else {
return err
@@ -469,14 +548,6 @@ func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
return n.neigh.entries(), nil
}
-func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) {
- if n.neigh == nil {
- return
- }
-
- n.neigh.removeWaker(addr, w)
-}
-
func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error {
if n.neigh == nil {
return tcpip.ErrNotSupported
@@ -524,8 +595,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
return tcpip.ErrNotSupported
}
- _, err := gep.JoinGroup(addr)
- return err
+ return gep.JoinGroup(addr)
}
// leaveGroup decrements the count for the given multicast address, and when it
@@ -541,11 +611,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres
return tcpip.ErrNotSupported
}
- if _, err := gep.LeaveGroup(addr); err != nil {
- return err
- }
-
- return nil
+ return gep.LeaveGroup(addr)
}
// isInGroup returns true if n has joined the multicast group addr.
@@ -564,13 +630,6 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
return false
}
-func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
- r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
- defer r.Release()
- r.PopulatePacketInfo(pkt)
- n.getNetworkEndpoint(protocol).HandlePacket(pkt)
-}
-
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
// the NIC receives a packet from the link endpoint.
@@ -592,7 +651,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n.stats.Rx.Packets.Increment()
n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
- netProto, ok := n.stack.networkProtocols[protocol]
+ networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
n.mu.RUnlock()
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -607,21 +666,26 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
// Are any packet type sockets listening for this network protocol?
- packetEPs := n.mu.packetEPs[protocol]
- // Add any other packet type sockets that may be listening for all protocols.
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
+ protoEPs := n.mu.packetEPs[protocol]
+ // Other packet type sockets that are listening for all protocols.
+ anyEPs := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ // Deliver to interested packet endpoints without holding NIC lock.
+ deliverPacketEPs := func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketHost
ep.HandlePacket(n.id, local, protocol, p)
}
-
- if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
- n.stack.stats.IP.PacketsReceived.Increment()
+ if protoEPs != nil {
+ protoEPs.forEach(deliverPacketEPs)
+ }
+ if anyEPs != nil {
+ anyEPs.forEach(deliverPacketEPs)
}
// Parse headers.
+ netProto := n.stack.NetworkProtocolInstance(protocol)
transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
if !ok {
// The packet is too small to contain a network header.
@@ -636,9 +700,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
- src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
-
if n.stack.handleLocal && !n.IsLoopback() {
+ src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View())
if r := n.getAddress(protocol, src); r != nil {
r.DecRef()
@@ -651,78 +714,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
- // Loopback traffic skips the prerouting chain.
- if !n.IsLoopback() {
- // iptables filtering.
- ipt := n.stack.IPTables()
- address := n.primaryAddress(protocol)
- if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
- // iptables is telling us to drop the packet.
- n.stack.stats.IP.IPTablesPreroutingDropped.Increment()
- return
- }
- }
-
- if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil {
- n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt)
- return
- }
-
- // This NIC doesn't care about the packet. Find a NIC that cares about the
- // packet and forward it to the NIC.
- //
- // TODO: Should we be forwarding the packet even if promiscuous?
- if n.stack.Forwarding(protocol) {
- r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
- if err != nil {
- n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- return
- }
-
- // Found a NIC.
- n := r.localAddressNIC
- if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
- if n.isValidForOutgoing(addressEndpoint) {
- pkt.NICID = n.ID()
- r.RemoteAddress = src
- pkt.NetworkPacketInfo = r.networkPacketInfo()
- n.getNetworkEndpoint(protocol).HandlePacket(pkt)
- addressEndpoint.DecRef()
- r.Release()
- return
- }
-
- addressEndpoint.DecRef()
- }
-
- // n doesn't have a destination endpoint.
- // Send the packet out of n.
- // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease
- // the TTL field for ipv4/ipv6.
-
- // pkt may have set its header and may not have enough headroom for
- // link-layer header for the other link to prepend. Here we create a new
- // packet to forward.
- fwdPkt := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()),
- // We need to do a deep copy of the IP packet because WritePacket (and
- // friends) take ownership of the packet buffer, but we do not own it.
- Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
- })
-
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
- if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil {
- n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- }
-
- r.Release()
- return
- }
-
- // If a packet socket handled the packet, don't treat it as invalid.
- if len(packetEPs) == 0 {
- n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- }
+ networkEndpoint.HandlePacket(pkt)
}
// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
@@ -731,16 +723,17 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
// We do not deliver to protocol specific packet endpoints as on Linux
// only ETH_P_ALL endpoints get outbound packets.
// Add any other packet sockets that maybe listening for all protocols.
- packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ eps := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ eps.forEach(func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
n.LinkEndpoint.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
- }
+ })
}
// DeliverTransportPacket delivers the packets to the appropriate transport
@@ -893,7 +886,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
if !ok {
return tcpip.ErrNotSupported
}
- n.mu.packetEPs[netProto] = append(eps, ep)
+ eps.add(ep)
return nil
}
@@ -906,13 +899,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
if !ok {
return
}
-
- for i, epOther := range eps {
- if epOther == ep {
- n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
- return
- }
- }
+ eps.remove(ep)
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index ab629b3a4..12d67409a 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -109,14 +109,6 @@ const (
//
// Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10.
defaultMaxReachbilityConfirmations = 3
-
- // defaultUnreachableTime is the default duration for how long an entry will
- // remain in the FAILED state before being removed from the neighbor cache.
- //
- // Note, there is no equivalent protocol constant defined in RFC 4861. It
- // leaves the specifics of any garbage collection mechanism up to the
- // implementation.
- defaultUnreachableTime = 5 * time.Second
)
// NUDDispatcher is the interface integrators of netstack must implement to
@@ -278,10 +270,6 @@ type NUDConfigurations struct {
// TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD
// configuration option is necessary.
MaxReachabilityConfirmations uint32
-
- // UnreachableTime describes how long an entry will remain in the FAILED
- // state before being removed from the neighbor cache.
- UnreachableTime time.Duration
}
// DefaultNUDConfigurations returns a NUDConfigurations populated with default
@@ -299,7 +287,6 @@ func DefaultNUDConfigurations() NUDConfigurations {
MaxUnicastProbes: defaultMaxUnicastProbes,
MaxAnycastDelayTime: defaultMaxAnycastDelayTime,
MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations,
- UnreachableTime: defaultUnreachableTime,
}
}
@@ -329,9 +316,6 @@ func (c *NUDConfigurations) resetInvalidFields() {
if c.MaxUnicastProbes == 0 {
c.MaxUnicastProbes = defaultMaxUnicastProbes
}
- if c.UnreachableTime == 0 {
- c.UnreachableTime = defaultUnreachableTime
- }
}
// calcMaxRandomFactor calculates the maximum value of the random factor used
@@ -416,7 +400,7 @@ func (s *NUDState) ReachableTime() time.Duration {
s.config.BaseReachableTime != s.prevBaseReachableTime ||
s.config.MinRandomFactor != s.prevMinRandomFactor ||
s.config.MaxRandomFactor != s.prevMaxRandomFactor {
- return s.recomputeReachableTimeLocked()
+ s.recomputeReachableTimeLocked()
}
return s.reachableTime
}
@@ -442,7 +426,7 @@ func (s *NUDState) ReachableTime() time.Duration {
// random value gets re-computed at least once every few hours.
//
// s.mu MUST be locked for writing.
-func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
+func (s *NUDState) recomputeReachableTimeLocked() {
s.prevBaseReachableTime = s.config.BaseReachableTime
s.prevMinRandomFactor = s.config.MinRandomFactor
s.prevMaxRandomFactor = s.config.MaxRandomFactor
@@ -462,5 +446,4 @@ func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
}
s.expiration = time.Now().Add(2 * time.Hour)
- return s.reachableTime
}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index 8cffb9fc6..7bca1373e 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -37,7 +37,6 @@ const (
defaultMaxUnicastProbes = 3
defaultMaxAnycastDelayTime = time.Second
defaultMaxReachbilityConfirmations = 3
- defaultUnreachableTime = 5 * time.Second
defaultFakeRandomNum = 0.5
)
@@ -565,58 +564,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
}
}
-func TestNUDConfigurationsUnreachableTime(t *testing.T) {
- tests := []struct {
- name string
- unreachableTime time.Duration
- want time.Duration
- }{
- // Invalid cases
- {
- name: "EqualToZero",
- unreachableTime: 0,
- want: defaultUnreachableTime,
- },
- // Valid cases
- {
- name: "MoreThanZero",
- unreachableTime: time.Millisecond,
- want: time.Millisecond,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const nicID = 1
-
- c := stack.DefaultNUDConfigurations()
- c.UnreachableTime = test.unreachableTime
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The
- // networking stack will only allocate neighbor caches if a protocol
- // providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NUDConfigs: c,
- UseNeighborCache: true,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- sc, err := s.NUDConfigurations(nicID)
- if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
- }
- if got := sc.UnreachableTime; got != test.want {
- t.Errorf("got UnreachableTime = %q, want = %q", got, test.want)
- }
- })
- }
-}
-
// TestNUDStateReachableTime verifies the correctness of the ReachableTime
// computation.
func TestNUDStateReachableTime(t *testing.T) {
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 5d364a2b0..4a3adcf33 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -103,7 +103,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
for _, p := range packets {
if cancelled {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
- } else if _, err := p.route.Resolve(nil); err != nil {
+ } else if p.route.IsResolutionRequired() {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 00e9a82ae..4795208b4 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -17,7 +17,6 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -65,10 +64,6 @@ const (
// NetworkPacketInfo holds information about a network layer packet.
type NetworkPacketInfo struct {
- // RemoteAddressBroadcast is true if the packet's remote address is a
- // broadcast address.
- RemoteAddressBroadcast bool
-
// LocalAddressBroadcast is true if the packet's local address is a broadcast
// address.
LocalAddressBroadcast bool
@@ -89,7 +84,7 @@ type TransportEndpoint interface {
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
// HandleControlPacket takes ownership of pkt.
- HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer)
+ HandleControlPacket(typ ControlType, extra uint32, pkt *PacketBuffer)
// Abort initiates an expedited endpoint teardown. It puts the endpoint
// in a closed state and frees all resources associated with it. This
@@ -263,15 +258,6 @@ const (
PacketLoop
)
-// NetOptions is an interface that allows us to pass network protocol specific
-// options through the Stack layer code.
-type NetOptions interface {
- // AllocationSize returns the amount of memory that must be allocated to
- // hold the options given that the value must be rounded up to the next
- // multiple of 4 bytes.
- AllocationSize() int
-}
-
// NetworkHeaderParams are the header parameters given as input by the
// transport endpoint to the network.
type NetworkHeaderParams struct {
@@ -283,10 +269,6 @@ type NetworkHeaderParams struct {
// TOS refers to TypeOfService or TrafficClass field of the IP-header.
TOS uint8
-
- // Options is a set of options to add to a network header (or nil).
- // It will be protocol specific opaque information from higher layers.
- Options NetOptions
}
// GroupAddressableEndpoint is an endpoint that supports group addressing.
@@ -295,14 +277,10 @@ type NetworkHeaderParams struct {
// endpoints may associate themselves with the same identifier (group address).
type GroupAddressableEndpoint interface {
// JoinGroup joins the specified group.
- //
- // Returns true if the group was newly joined.
- JoinGroup(group tcpip.Address) (bool, *tcpip.Error)
+ JoinGroup(group tcpip.Address) *tcpip.Error
// LeaveGroup attempts to leave the specified group.
- //
- // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group.
- LeaveGroup(group tcpip.Address) (bool, *tcpip.Error)
+ LeaveGroup(group tcpip.Address) *tcpip.Error
// IsInGroup returns true if the endpoint is a member of the specified group.
IsInGroup(group tcpip.Address) bool
@@ -518,6 +496,9 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
+ // Promiscuous returns true if the interface is in promiscuous mode.
+ Promiscuous() bool
+
// WritePacketToRemote writes the packet to the given remote link address.
WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
}
@@ -525,8 +506,6 @@ type NetworkInterface interface {
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
- AddressableEndpoint
-
// Enable enables the endpoint.
//
// Must only be called when the stack is in a state that allows the endpoint
@@ -742,10 +721,6 @@ type LinkEndpoint interface {
// endpoint.
Capabilities() LinkEndpointCapabilities
- // WriteRawPacket writes a packet directly to the link. The packet
- // should already have an ethernet header. It takes ownership of vv.
- WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error
-
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
//
@@ -823,19 +798,26 @@ type LinkAddressCache interface {
// AddLinkAddress adds a link address to the cache.
AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
- // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
- // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
- // registered with the network protocol, the cache attempts to resolve the address
- // and returns ErrWouldBlock. Waker is notified when address resolution is
- // complete (success or not).
+ // GetLinkAddress finds the link address corresponding to the remote address
+ // (e.g. IP -> MAC).
//
- // If address resolution is required, ErrNoLinkAddress and a notification channel is
- // returned for the top level caller to block. Channel is closed once address resolution
- // is complete (success or not).
- GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
-
- // RemoveWaker removes a waker that has been added in GetLinkAddress().
- RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+ // Returns a link address for the remote address, if readily available.
+ //
+ // Returns ErrWouldBlock if the link address is not readily available, along
+ // with a notification channel for the caller to block on. Triggers address
+ // resolution asynchronously.
+ //
+ // If onResolve is provided, it will be called either immediately, if
+ // resolution is not required, or when address resolution is complete, with
+ // the resolved link address and whether resolution succeeded. After any
+ // callbacks have been called, the returned notification channel is closed.
+ //
+ // If specified, the local address must be an address local to the interface
+ // the neighbor cache belongs to. The local address is the source address of
+ // a packet prompting NUD/link address resolution.
+ //
+ // TODO(gvisor.dev/issue/5151): Don't return the link address.
+ GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
}
// RawFactory produces endpoints for writing various types of raw packets.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 15ff437c7..b0251d0b4 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -17,20 +17,53 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// Route represents a route through the networking stack to a given destination.
+//
+// It is safe to call Route's methods from multiple goroutines.
+//
+// The exported fields are immutable.
+//
+// TODO(gvisor.dev/issue/4902): Unexpose immutable fields.
type Route struct {
+ routeInfo
+
+ // localAddressNIC is the interface the address is associated with.
+ // TODO(gvisor.dev/issue/4548): Remove this field once we can query the
+ // address's assigned status without the NIC.
+ localAddressNIC *NIC
+
+ mu struct {
+ sync.RWMutex
+
+ // localAddressEndpoint is the local address this route is associated with.
+ localAddressEndpoint AssignableAddressEndpoint
+
+ // remoteLinkAddress is the link-layer (MAC) address of the next hop in the
+ // route.
+ remoteLinkAddress tcpip.LinkAddress
+ }
+
+ // outgoingNIC is the interface this route uses to write packets.
+ outgoingNIC *NIC
+
+ // linkCache is set if link address resolution is enabled for this protocol on
+ // the route's NIC.
+ linkCache LinkAddressCache
+
+ // linkRes is set if link address resolution is enabled for this protocol on
+ // the route's NIC.
+ linkRes LinkAddressResolver
+}
+
+type routeInfo struct {
// RemoteAddress is the final destination of the route.
RemoteAddress tcpip.Address
- // RemoteLinkAddress is the link-layer (MAC) address of the
- // final destination of the route.
- RemoteLinkAddress tcpip.LinkAddress
-
// LocalAddress is the local address where the route starts.
LocalAddress tcpip.Address
@@ -46,47 +79,48 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
+}
- // localAddressNIC is the interface the address is associated with.
- // TODO(gvisor.dev/issue/4548): Remove this field once we can query the
- // address's assigned status without the NIC.
- localAddressNIC *NIC
-
- // localAddressEndpoint is the local address this route is associated with.
- localAddressEndpoint AssignableAddressEndpoint
-
- // outgoingNIC is the interface this route uses to write packets.
- outgoingNIC *NIC
+// RouteInfo contains all of Route's exported fields.
+type RouteInfo struct {
+ routeInfo
- // linkCache is set if link address resolution is enabled for this protocol on
- // the route's NIC.
- linkCache LinkAddressCache
+ // RemoteLinkAddress is the link-layer (MAC) address of the next hop in the
+ // route.
+ RemoteLinkAddress tcpip.LinkAddress
+}
- // linkRes is set if link address resolution is enabled for this protocol on
- // the route's NIC.
- linkRes LinkAddressResolver
+// GetFields returns a RouteInfo with all of r's exported fields. This allows
+// callers to store the route's fields without retaining a reference to it.
+func (r *Route) GetFields() RouteInfo {
+ return RouteInfo{
+ routeInfo: r.routeInfo,
+ RemoteLinkAddress: r.RemoteLinkAddress(),
+ }
}
// constructAndValidateRoute validates and initializes a route. It takes
// ownership of the provided local address.
//
// Returns an empty route if validation fails.
-func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route {
- addrWithPrefix := addressEndpoint.AddressWithPrefix()
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route {
+ if len(localAddr) == 0 {
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ }
- if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) {
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) {
addressEndpoint.DecRef()
- return Route{}
+ return nil
}
// If no remote address is provided, use the local address.
if len(remoteAddr) == 0 {
- remoteAddr = addrWithPrefix.Address
+ remoteAddr = localAddr
}
r := makeRoute(
netProto,
- addrWithPrefix.Address,
+ localAddr,
remoteAddr,
outgoingNIC,
localAddressNIC,
@@ -99,8 +133,8 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// broadcast it.
if len(gateway) > 0 {
r.NextHop = gateway
- } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) {
+ r.ResolveWith(header.EthernetBroadcastAddress)
}
return r
@@ -108,11 +142,15 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route {
if localAddressNIC.stack != outgoingNIC.stack {
panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
}
+ if len(localAddr) == 0 {
+ localAddr = localAddressEndpoint.AddressWithPrefix().Address
+ }
+
loop := PacketOut
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
@@ -133,18 +171,23 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
-func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route {
- r := Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- localAddressNIC: localAddressNIC,
- localAddressEndpoint: localAddressEndpoint,
- outgoingNIC: outgoingNIC,
- Loop: loop,
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route {
+ r := &Route{
+ routeInfo: routeInfo{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ Loop: loop,
+ },
+ localAddressNIC: localAddressNIC,
+ outgoingNIC: outgoingNIC,
}
+ r.mu.Lock()
+ r.mu.localAddressEndpoint = localAddressEndpoint
+ r.mu.Unlock()
+
if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
@@ -159,7 +202,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
// provided AssignableAddressEndpoint.
//
// A local route is a route to a destination that is local to the stack.
-func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route {
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route {
loop := PacketLoop
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
// link endpoint level. We can remove this check once loopback interfaces
@@ -170,26 +213,12 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
-// PopulatePacketInfo populates a packet buffer's packet information fields.
-//
-// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
-// the network layer.
-func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) {
- if r.local() {
- pkt.RXTransportChecksumValidated = true
- }
- pkt.NetworkPacketInfo = r.networkPacketInfo()
-}
-
-// networkPacketInfo returns the network packet information of the route.
-//
-// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
-// the network layer.
-func (r *Route) networkPacketInfo() NetworkPacketInfo {
- return NetworkPacketInfo{
- RemoteAddressBroadcast: r.IsOutboundBroadcast(),
- LocalAddressBroadcast: r.isInboundBroadcast(),
- }
+// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in
+// the route.
+func (r *Route) RemoteLinkAddress() tcpip.LinkAddress {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.mu.remoteLinkAddress
}
// NICID returns the id of the NIC from which this route originates.
@@ -253,22 +282,26 @@ func (r *Route) GSOMaxSize() uint32 {
// ResolveWith immediately resolves a route with the specified remote link
// address.
func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
- r.RemoteLinkAddress = addr
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.mu.remoteLinkAddress = addr
}
-// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
-// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
-// notified when address resolution is complete (success or not).
-//
-// If address resolution is required, ErrNoLinkAddress and a notification channel is
-// returned for the top level caller to block. Channel is closed once address resolution
-// is complete (success or not).
+// Resolve attempts to resolve the link address if necessary.
//
-// The NIC r uses must not be locked.
-func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
- if !r.IsResolutionRequired() {
+// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g.
+// waiting for ARP reply). If address resolution is required, a notification
+// channel is also returned for the caller to block on. The channel is closed
+// once address resolution is complete (successful or not). If a callback is
+// provided, it will be called when address resolution is complete, regardless
+// of success or failure.
+func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
+ r.mu.Lock()
+
+ if !r.isResolutionRequiredRLocked() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
// link address is already known.
+ r.mu.Unlock()
return nil, nil
}
@@ -276,7 +309,8 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if nextAddr == "" {
// Local link address is already known.
if r.RemoteAddress == r.LocalAddress {
- r.RemoteLinkAddress = r.LocalLinkAddress
+ r.mu.remoteLinkAddress = r.LocalLinkAddress
+ r.mu.Unlock()
return nil, nil
}
nextAddr = r.RemoteAddress
@@ -289,38 +323,36 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
linkAddressResolutionRequestLocalAddr = r.LocalAddress
}
+ // Increment the route's reference count because finishResolution retains a
+ // reference to the route and releases it when called.
+ r.acquireLocked()
+ r.mu.Unlock()
+
+ finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) {
+ if ok {
+ r.ResolveWith(linkAddress)
+ }
+ if afterResolve != nil {
+ afterResolve()
+ }
+ r.Release()
+ }
+
if neigh := r.outgoingNIC.neigh; neigh != nil {
- entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker)
+ _, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution)
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = entry.LinkAddr
return nil, nil
}
- linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker)
+ _, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, finishResolution)
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = linkAddr
return nil, nil
}
-// RemoveWaker removes a waker that has been added in Resolve().
-func (r *Route) RemoveWaker(waker *sleep.Waker) {
- nextAddr := r.NextHop
- if nextAddr == "" {
- nextAddr = r.RemoteAddress
- }
-
- if neigh := r.outgoingNIC.neigh; neigh != nil {
- neigh.removeWaker(nextAddr, waker)
- return
- }
-
- r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker)
-}
-
// local returns true if the route is a local route.
func (r *Route) local() bool {
return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback()
@@ -331,7 +363,13 @@ func (r *Route) local() bool {
//
// The NICs the route is associated with must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isResolutionRequiredRLocked()
+}
+
+func (r *Route) isResolutionRequiredRLocked() bool {
+ if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() {
return false
}
@@ -339,11 +377,18 @@ func (r *Route) IsResolutionRequired() bool {
}
func (r *Route) isValidForOutgoing() bool {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isValidForOutgoingRLocked()
+}
+
+func (r *Route) isValidForOutgoingRLocked() bool {
if !r.outgoingNIC.Enabled() {
return false
}
- if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) {
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) {
return false
}
@@ -395,39 +440,31 @@ func (r *Route) MTU() uint32 {
return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU()
}
-// Release frees all resources associated with the route.
+// Release decrements the reference counter of the resources associated with the
+// route.
func (r *Route) Release() {
- if r.localAddressEndpoint != nil {
- r.localAddressEndpoint.DecRef()
- r.localAddressEndpoint = nil
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if ep := r.mu.localAddressEndpoint; ep != nil {
+ ep.DecRef()
}
}
-// Clone clones the route.
-func (r *Route) Clone() Route {
- if r.localAddressEndpoint != nil {
- if !r.localAddressEndpoint.IncRef() {
+// Acquire increments the reference counter of the resources associated with the
+// route.
+func (r *Route) Acquire() {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ r.acquireLocked()
+}
+
+func (r *Route) acquireLocked() {
+ if ep := r.mu.localAddressEndpoint; ep != nil {
+ if !ep.IncRef() {
panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
}
}
- return *r
-}
-
-// MakeLoopedRoute duplicates the given route with special handling for routes
-// used for sending multicast or broadcast packets. In those cases the
-// multicast/broadcast address is the remote address when sending out, but for
-// incoming (looped) packets it becomes the local address. Similarly, the local
-// interface address that was the local address going out becomes the remote
-// address coming in. This is different to unicast routes where local and
-// remote addresses remain the same as they identify location (local vs remote)
-// not direction (source vs destination).
-func (r *Route) MakeLoopedRoute() Route {
- l := r.Clone()
- if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
- l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress
- l.RemoteLinkAddress = l.LocalLinkAddress
- }
- return l
}
// Stack returns the instance of the Stack that owns this route.
@@ -440,7 +477,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.localAddressEndpoint.Subnet()
+ r.mu.RLock()
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ r.mu.RUnlock()
+ if localAddressEndpoint == nil {
+ return false
+ }
+
+ subnet := localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
@@ -450,27 +494,3 @@ func (r *Route) IsOutboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.RemoteAddress)
}
-
-// isInboundBroadcast returns true if the route is for an inbound broadcast
-// packet.
-func (r *Route) isInboundBroadcast() bool {
- // Only IPv4 has a notion of broadcast.
- return r.isV4Broadcast(r.LocalAddress)
-}
-
-// ReverseRoute returns new route with given source and destination address.
-func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
- return Route{
- NetProto: r.NetProto,
- LocalAddress: dst,
- LocalLinkAddress: r.RemoteLinkAddress,
- RemoteAddress: src,
- RemoteLinkAddress: r.LocalLinkAddress,
- Loop: r.Loop,
- localAddressNIC: r.localAddressNIC,
- localAddressEndpoint: r.localAddressEndpoint,
- outgoingNIC: r.outgoingNIC,
- linkCache: r.linkCache,
- linkRes: r.linkRes,
- }
-}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 0fe157128..114643b03 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -29,7 +29,6 @@ import (
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -82,6 +81,7 @@ type TCPRACKState struct {
FACK seqnum.Value
RTT time.Duration
Reord bool
+ DSACKSeen bool
}
// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
@@ -170,6 +170,9 @@ type TCPSenderState struct {
// Outstanding is the number of packets in flight.
Outstanding int
+ // SackedOut is the number of packets which have been selectively acked.
+ SackedOut int
+
// SndWnd is the send window size in bytes.
SndWnd seqnum.Size
@@ -1080,7 +1083,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
Running: nic.Enabled(),
- Promiscuous: nic.isPromiscuousMode(),
+ Promiscuous: nic.Promiscuous(),
Loopback: nic.IsLoopback(),
}
nics[id] = NICInfo{
@@ -1117,6 +1120,16 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber,
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
+// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
+// the address prefix.
+func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error {
+ ap := tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: addr,
+ }
+ return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint)
+}
+
// AddProtocolAddress adds a new network-layer protocol address to the
// specified NIC.
func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error {
@@ -1207,10 +1220,10 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
// from the specified NIC.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
if localAddressEndpoint == nil {
- return Route{}, false
+ return nil
}
var outgoingNIC *NIC
@@ -1234,12 +1247,12 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// route.
if outgoingNIC == nil {
localAddressEndpoint.DecRef()
- return Route{}, false
+ return nil
}
r := makeLocalRoute(
netProto,
- localAddressEndpoint.AddressWithPrefix().Address,
+ localAddr,
remoteAddr,
outgoingNIC,
localAddressNIC,
@@ -1248,10 +1261,10 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
if r.IsOutboundBroadcast() {
r.Release()
- return Route{}, false
+ return nil
}
- return r, true
+ return r
}
// findLocalRouteRLocked returns a local route.
@@ -1260,26 +1273,26 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// is, a local route is a route where packets never have to leave the stack.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
if len(localAddr) == 0 {
localAddr = remoteAddr
}
if localAddressNICID == 0 {
for _, localAddressNIC := range s.nics {
- if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok {
- return r, true
+ if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil {
+ return r
}
}
- return Route{}, false
+ return nil
}
if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
}
- return Route{}, false
+ return nil
}
// FindRoute creates a route to the given destination address, leaving through
@@ -1293,7 +1306,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr,
// If no local address is provided, the stack will select a local address. If no
// remote address is provided, the stack wil use a remote address equal to the
// local address.
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1304,7 +1317,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
if s.handleLocal && !isMulticast && !isLocalBroadcast {
- if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok {
+ if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil {
return r, nil
}
}
@@ -1316,7 +1329,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
return makeRoute(
netProto,
- addressEndpoint.AddressWithPrefix().Address,
+ localAddr,
remoteAddr,
nic, /* outboundNIC */
nic, /* localAddressNIC*/
@@ -1328,9 +1341,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if isLoopback {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
@@ -1353,8 +1366,8 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if needRoute {
gateway = route.Gateway
}
- r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop)
- if r == (Route{}) {
+ r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop)
+ if r == nil {
panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
}
return r, nil
@@ -1390,13 +1403,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if id != 0 {
if aNIC, ok := s.nics[id]; ok {
if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
}
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if id == 0 {
@@ -1408,7 +1421,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
continue
}
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
@@ -1416,12 +1429,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if needRoute {
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if header.IsV6LoopbackAddress(remoteAddr) {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1506,7 +1519,7 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t
}
// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
-func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
s.mu.RLock()
nic := s.nics[nicID]
if nic == nil {
@@ -1517,7 +1530,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, onResolve)
}
// Neighbors returns all IP to MAC address associations.
@@ -1533,29 +1546,6 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) {
return nic.neighbors()
}
-// RemoveWaker removes a waker that has been added when link resolution for
-// addr was requested.
-func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
- if s.useNeighborCache {
- s.mu.RLock()
- nic, ok := s.nics[nicID]
- s.mu.RUnlock()
-
- if ok {
- nic.removeWaker(addr, waker)
- }
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- if nic := s.nics[nicID]; nic == nil {
- fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
- s.linkAddrCache.removeWaker(fullAddr, waker)
- }
-}
-
// AddStaticNeighbor statically associates an IP address to a MAC address.
func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error {
s.mu.RLock()
@@ -1809,49 +1799,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip
nic.unregisterPacketEndpoint(netProto, ep)
}
-// WritePacket writes data directly to the specified NIC. It adds an ethernet
-// header based on the arguments.
-func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
+// WritePacketToRemote writes a payload on the specified NIC using the provided
+// network protocol and remote link address.
+func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
s.mu.Lock()
nic, ok := s.nics[nicID]
s.mu.Unlock()
if !ok {
return tcpip.ErrUnknownDevice
}
-
- // Add our own fake ethernet header.
- ethFields := header.EthernetFields{
- SrcAddr: nic.LinkEndpoint.LinkAddress(),
- DstAddr: dst,
- Type: netProto,
- }
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- vv := buffer.View(fakeHeader).ToVectorisedView()
- vv.Append(payload)
-
- if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
- return err
- }
-
- return nil
-}
-
-// WriteRawPacket writes data directly to the specified NIC without adding any
-// headers.
-func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
- s.mu.Lock()
- nic, ok := s.nics[nicID]
- s.mu.Unlock()
- if !ok {
- return tcpip.ErrUnknownDevice
- }
-
- if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
- return err
- }
-
- return nil
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()),
+ Data: payload,
+ })
+ return nic.WritePacketToRemote(remote, nil, netProto, pkt)
}
// NetworkProtocolInstance returns the protocol instance in the stack for the
@@ -1911,7 +1872,6 @@ func (s *Stack) RemoveTCPProbe() {
// JoinGroup joins the given multicast group on the given NIC.
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
- // TODO: notify network of subscription via igmp protocol.
s.mu.RLock()
defer s.mu.RUnlock()
@@ -2158,3 +2118,43 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber {
}
return protos
}
+
+func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint)
+ if addressEndpoint == nil {
+ return false
+ }
+
+ subnet := addressEndpoint.Subnet()
+ addressEndpoint.DecRef()
+ return subnet.IsBroadcast(addr)
+}
+
+// IsSubnetBroadcast returns true if the provided address is a subnet-local
+// broadcast address on the specified NIC and protocol.
+//
+// Returns false if the NIC is unknown or if the protocol is unknown or does
+// not support addressing.
+//
+// If the NIC is not specified, the stack will check all NICs.
+func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nicID != 0 {
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return false
+ }
+
+ return isSubnetBroadcastOnNIC(nic, protocol, addr)
+ }
+
+ for _, nic := range s.nics {
+ if isSubnetBroadcastOnNIC(nic, protocol, addr) {
+ return true
+ }
+ }
+
+ return false
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index dedfdd435..856ebf6d4 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -27,7 +27,6 @@ import (
"time"
"github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -112,7 +111,15 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
netHdr := pkt.NetworkHeader().View()
- f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++
+
+ dst := tcpip.Address(netHdr[dstAddrOffset:][:1])
+ addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
+ return
+ }
+ addressEndpoint.DecRef()
+
+ f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
@@ -159,9 +166,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
hdr[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- pkt := pkt.Clone()
- r.PopulatePacketInfo(pkt)
- f.HandlePacket(pkt)
+ f.HandlePacket(pkt.Clone())
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -401,7 +406,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
return send(r, payload)
}
-func send(r stack.Route, payload buffer.View) *tcpip.Error {
+func send(r *stack.Route, payload buffer.View) *tcpip.Error {
return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: payload.ToVectorisedView(),
@@ -419,7 +424,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En
}
}
-func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
+func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) {
t.Helper()
ep.Drain()
if err := send(r, payload); err != nil {
@@ -430,7 +435,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.
}
}
-func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
@@ -1557,15 +1562,15 @@ func TestSpoofingNoAddress(t *testing.T) {
// testSendTo(t, s, remoteAddr, ep, nil)
}
-func verifyRoute(gotRoute, wantRoute stack.Route) error {
+func verifyRoute(gotRoute, wantRoute *stack.Route) error {
if gotRoute.LocalAddress != wantRoute.LocalAddress {
return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress)
}
if gotRoute.RemoteAddress != wantRoute.RemoteAddress {
return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress)
}
- if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress {
- return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress)
+ if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want {
+ return fmt.Errorf("bad remote link address: got %s, want = %s", got, want)
}
if gotRoute.NextHop != wantRoute.NextHop {
return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop)
@@ -1597,7 +1602,10 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ var wantRoute stack.Route
+ wantRoute.LocalAddress = header.IPv4Any
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1651,7 +1659,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ var wantRoute stack.Route
+ wantRoute.LocalAddress = nic1Addr.Address
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1661,7 +1672,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ wantRoute = stack.Route{}
+ wantRoute.LocalAddress = nic2Addr.Address
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1677,7 +1691,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ wantRoute = stack.Route{}
+ wantRoute.LocalAddress = nic1Addr.Address
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
}
@@ -2214,88 +2231,6 @@ func TestNICStats(t *testing.T) {
}
}
-func TestNICForwarding(t *testing.T) {
- const nicID1 = 1
- const nicID2 = 2
- const dstAddr = tcpip.Address("\x03")
-
- tests := []struct {
- name string
- headerLen uint16
- }{
- {
- name: "Zero header length",
- },
- {
- name: "Non-zero header length",
- headerLen: 16,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- s.SetForwarding(fakeNetNumber, true)
-
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
- }
-
- ep2 := channelLinkWithHeaderLength{
- Endpoint: channel.New(10, defaultMTU, ""),
- headerLength: test.headerLen,
- }
- if err := s.CreateNIC(nicID2, &ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
- if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
- }
-
- // Route all packets to dstAddr to NIC 2.
- {
- subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
- }
-
- // Send a packet to dstAddr.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = dstAddr[0]
- ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- pkt, ok := ep2.Read()
- if !ok {
- t.Fatal("packet not forwarded")
- }
-
- // Test that the link's MaxHeaderLength is honoured.
- if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want {
- t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want)
- }
-
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
-
// TestNICContextPreservation tests that you can read out via stack.NICInfo the
// Context data you pass via NICContext.Context in stack.CreateNICWithOptions.
func TestNICContextPreservation(t *testing.T) {
@@ -2483,9 +2418,9 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
}
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: test.autoGen,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: test.iidOpts,
+ AutoGenLinkLocal: test.autoGen,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: test.iidOpts,
})},
}
@@ -2578,8 +2513,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: true,
- OpaqueIIDOpts: test.opaqueIIDOpts,
+ AutoGenLinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
})},
}
@@ -2612,9 +2547,9 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
ndpConfigs := ipv6.DefaultNDPConfigurations()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ndpConfigs,
- AutoGenIPv6LinkLocal: true,
- NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ AutoGenLinkLocal: true,
+ NDPDisp: &ndpDisp,
})},
}
@@ -2803,8 +2738,16 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- nicID = 1
- lifetimeSeconds = 9999
+ globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01")
+ ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02")
+ toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+
+ nicID = 1
+ lifetimeSeconds = 9999
)
prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1)
@@ -2821,139 +2764,191 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix
nicAddrs []tcpip.Address
slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix
- connectAddr tcpip.Address
+ remoteAddr tcpip.Address
expectedLocalAddr tcpip.Address
}{
- // Test Rule 1 of RFC 6724 section 5.
+ // Test Rule 1 of RFC 6724 section 5 (prefer same address).
{
name: "Same Global most preferred (last address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr1,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: globalAddr1,
expectedLocalAddr: globalAddr1,
},
{
name: "Same Global most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: globalAddr1,
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1},
+ remoteAddr: globalAddr1,
expectedLocalAddr: globalAddr1,
},
{
name: "Same Link Local most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalAddr1,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: linkLocalAddr1,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Same Link Local most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalAddr1,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: linkLocalAddr1,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Same Unique Local most preferred (last address)",
- nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
- connectAddr: uniqueLocalAddr1,
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1},
+ remoteAddr: uniqueLocalAddr1,
expectedLocalAddr: uniqueLocalAddr1,
},
{
name: "Same Unique Local most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: uniqueLocalAddr1,
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1},
+ remoteAddr: uniqueLocalAddr1,
expectedLocalAddr: uniqueLocalAddr1,
},
- // Test Rule 2 of RFC 6724 section 5.
+ // Test Rule 2 of RFC 6724 section 5 (prefer appropriate scope).
{
name: "Global most preferred (last address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr2,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: globalAddr2,
expectedLocalAddr: globalAddr1,
},
{
name: "Global most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: globalAddr2,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: globalAddr2,
expectedLocalAddr: globalAddr1,
},
{
name: "Link Local most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalAddr2,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: linkLocalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalAddr2,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: linkLocalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local most preferred for link local multicast (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalMulticastAddr,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: linkLocalMulticastAddr,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local most preferred for link local multicast (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalMulticastAddr,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: linkLocalMulticastAddr,
expectedLocalAddr: linkLocalAddr1,
},
+
+ // Test Rule 6 of 6724 section 5 (prefer matching label).
{
name: "Unique Local most preferred (last address)",
- nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
- connectAddr: uniqueLocalAddr2,
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1},
+ remoteAddr: uniqueLocalAddr2,
expectedLocalAddr: uniqueLocalAddr1,
},
{
name: "Unique Local most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: uniqueLocalAddr2,
+ nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1},
+ remoteAddr: uniqueLocalAddr2,
expectedLocalAddr: uniqueLocalAddr1,
},
+ {
+ name: "Toredo most preferred (first address)",
+ nicAddrs: []tcpip.Address{toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1},
+ remoteAddr: toredoAddr2,
+ expectedLocalAddr: toredoAddr1,
+ },
+ {
+ name: "Toredo most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1},
+ remoteAddr: toredoAddr2,
+ expectedLocalAddr: toredoAddr1,
+ },
+ {
+ name: "6To4 most preferred (first address)",
+ nicAddrs: []tcpip.Address{ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1},
+ remoteAddr: ipv6ToIPv4Addr2,
+ expectedLocalAddr: ipv6ToIPv4Addr1,
+ },
+ {
+ name: "6To4 most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, uniqueLocalAddr1, toredoAddr1, ipv6ToIPv4Addr1},
+ remoteAddr: ipv6ToIPv4Addr2,
+ expectedLocalAddr: ipv6ToIPv4Addr1,
+ },
+ {
+ name: "IPv4 mapped IPv6 most preferred (first address)",
+ nicAddrs: []tcpip.Address{ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1},
+ remoteAddr: ipv4MappedIPv6Addr2,
+ expectedLocalAddr: ipv4MappedIPv6Addr1,
+ },
+ {
+ name: "IPv4 mapped IPv6 most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1, ipv4MappedIPv6Addr1},
+ remoteAddr: ipv4MappedIPv6Addr2,
+ expectedLocalAddr: ipv4MappedIPv6Addr1,
+ },
- // Test Rule 7 of RFC 6724 section 5.
+ // Test Rule 7 of RFC 6724 section 5 (prefer temporary addresses).
{
name: "Temp Global most preferred (last address)",
slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: tempGlobalAddr1,
},
{
name: "Temp Global most preferred (first address)",
nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
slaacPrefixForTempAddrAfterNICAddrAdd: prefix1,
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: tempGlobalAddr1,
},
+ // Test Rule 8 of RFC 6724 section 5 (use longest matching prefix).
+ {
+ name: "Longest prefix matched most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr2, globalAddr1},
+ remoteAddr: globalAddr3,
+ expectedLocalAddr: globalAddr2,
+ },
+ {
+ name: "Longest prefix matched most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, globalAddr2},
+ remoteAddr: globalAddr3,
+ expectedLocalAddr: globalAddr2,
+ },
+
// Test returning the endpoint that is closest to the front when
// candidate addresses are "equal" from the perspective of RFC 6724
// section 5.
{
name: "Unique Local for Global",
nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2},
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: uniqueLocalAddr1,
},
{
name: "Link Local for Global",
nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local for Unique Local",
nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
- connectAddr: uniqueLocalAddr2,
+ remoteAddr: uniqueLocalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Temp Global for Global",
slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
slaacPrefixForTempAddrAfterNICAddrAdd: prefix2,
- connectAddr: globalAddr1,
+ remoteAddr: globalAddr1,
expectedLocalAddr: tempGlobalAddr2,
},
}
@@ -2975,12 +2970,6 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- Gateway: llAddr3,
- NIC: nicID,
- }})
- s.AddLinkAddress(nicID, llAddr3, linkAddr3)
if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) {
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds))
@@ -3000,7 +2989,23 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
t.FailNow()
}
- if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr {
+ netEP, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ }
+
+ addressableEndpoint, ok := netEP.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("network endpoint is not addressable")
+ }
+
+ addressEP := addressableEndpoint.AcquireOutgoingPrimaryAddress(test.remoteAddr, false /* allowExpired */)
+ if addressEP == nil {
+ t.Fatal("expected a non-nil address endpoint")
+ }
+ defer addressEP.DecRef()
+
+ if got := addressEP.AddressWithPrefix().Address; got != test.expectedLocalAddr {
t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr)
}
})
@@ -3427,11 +3432,16 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
remNetSubnetBcast := remNetSubnet.Broadcast()
tests := []struct {
- name string
- nicAddr tcpip.ProtocolAddress
- routes []tcpip.Route
- remoteAddr tcpip.Address
- expectedRoute stack.Route
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ expectedLocalAddress tcpip.Address
+ expectedRemoteAddress tcpip.Address
+ expectedRemoteLinkAddress tcpip.LinkAddress
+ expectedNextHop tcpip.Address
+ expectedNetProto tcpip.NetworkProtocolNumber
+ expectedLoop stack.PacketLooping
}{
// Broadcast to a locally attached subnet populates the broadcast MAC.
{
@@ -3446,14 +3456,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: ipv4SubnetBcast,
- RemoteLinkAddress: header.EthernetBroadcastAddress,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut | stack.PacketLoop,
- },
+ remoteAddr: ipv4SubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: ipv4SubnetBcast,
+ expectedRemoteLinkAddress: header.EthernetBroadcastAddress,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut | stack.PacketLoop,
},
// Broadcast to a locally attached /31 subnet does not populate the
// broadcast MAC.
@@ -3469,13 +3477,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet31Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix31.Address,
- RemoteAddress: ipv4Subnet31Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet31Bcast,
+ expectedLocalAddress: ipv4AddrPrefix31.Address,
+ expectedRemoteAddress: ipv4Subnet31Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a locally attached /32 subnet does not populate the
// broadcast MAC.
@@ -3491,13 +3497,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet32Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix32.Address,
- RemoteAddress: ipv4Subnet32Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet32Bcast,
+ expectedLocalAddress: ipv4AddrPrefix32.Address,
+ expectedRemoteAddress: ipv4Subnet32Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// IPv6 has no notion of a broadcast.
{
@@ -3512,13 +3516,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv6SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv6Addr.Address,
- RemoteAddress: ipv6SubnetBcast,
- NetProto: header.IPv6ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv6SubnetBcast,
+ expectedLocalAddress: ipv6Addr.Address,
+ expectedRemoteAddress: ipv6SubnetBcast,
+ expectedNetProto: header.IPv6ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a remote subnet in the route table is send to the next-hop
// gateway.
@@ -3535,14 +3537,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to an unknown subnet follows the default route. Note that this
// is essentially just routing an unknown destination IP, because w/o any
@@ -3560,14 +3560,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
}
@@ -3596,10 +3594,27 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
t.Fatalf("got unexpected address length = %d bytes", l)
}
- if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil {
+ r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */)
+ if err != nil {
t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
- } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" {
- t.Errorf("route mismatch (-want +got):\n%s", diff)
+ }
+ if r.LocalAddress != test.expectedLocalAddress {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress)
+ }
+ if r.RemoteAddress != test.expectedRemoteAddress {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress)
+ }
+ if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress {
+ t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress)
+ }
+ if r.NextHop != test.expectedNextHop {
+ t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop)
+ }
+ if r.NetProto != test.expectedNetProto {
+ t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto)
+ }
+ if r.Loop != test.expectedLoop {
+ t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop)
}
})
}
@@ -4167,10 +4182,12 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ if r != nil {
+ defer r.Release()
+ }
if err != test.findRouteErr {
t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
}
- defer r.Release()
if test.findRouteErr != nil {
return
@@ -4228,3 +4245,63 @@ func TestFindRouteWithForwarding(t *testing.T) {
})
}
}
+
+func TestWritePacketToRemote(t *testing.T) {
+ const nicID = 1
+ const MTU = 1280
+ e := channel.New(1, MTU, linkAddr1)
+ s := stack.New(stack.Options{})
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("CreateNIC(%d) = %s", nicID, err)
+ }
+ tests := []struct {
+ name string
+ protocol tcpip.NetworkProtocolNumber
+ payload []byte
+ }{
+ {
+ name: "SuccessIPv4",
+ protocol: header.IPv4ProtocolNumber,
+ payload: []byte{1, 2, 3, 4},
+ },
+ {
+ name: "SuccessIPv6",
+ protocol: header.IPv6ProtocolNumber,
+ payload: []byte{5, 6, 7, 8},
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err)
+ }
+
+ pkt, ok := e.Read()
+ if got, want := ok, true; got != want {
+ t.Fatalf("e.Read() = %t, want %t", got, want)
+ }
+ if got, want := pkt.Proto, test.protocol; got != want {
+ t.Fatalf("pkt.Proto = %d, want %d", got, want)
+ }
+ if pkt.Route.RemoteLinkAddress != linkAddr2 {
+ t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2)
+ }
+ if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" {
+ t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+
+ t.Run("InvalidNICID", func(t *testing.T) {
+ if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want)
+ }
+ pkt, ok := e.Read()
+ if got, want := ok, false; got != want {
+ t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want)
+ }
+ })
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index f183ec6e4..07b2818d2 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -182,7 +182,8 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
}
-// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+// handleControlPacket delivers a control packet to the transport endpoint
+// identified by id.
func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -199,7 +200,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt)
+ selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(typ, extra, pkt)
}
// registerEndpoint returns true if it succeeds. It fails and returns
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 41a8e5ad0..859278f0b 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "io/ioutil"
"math"
"math/rand"
"testing"
@@ -141,11 +142,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: testSrcAddrV6,
- DstAddr: testDstAddrV6,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: testSrcAddrV6,
+ DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
@@ -307,12 +308,9 @@ func TestBindToDeviceDistribution(t *testing.T) {
}(ep)
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil {
- t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
- }
- bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
- if err := ep.SetSockOpt(&bindToDeviceOption); err != nil {
- t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err)
+ ep.SocketOptions().SetReusePort(endpoint.reuse)
+ if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil {
+ t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err)
}
var dstAddr tcpip.Address
@@ -354,7 +352,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
}
ep := <-pollChannel
- if _, _, err := ep.Read(nil); err != nil {
+ if _, err := ep.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != nil {
t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
}
stats[ep]++
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index c457b67a2..0ff32c6ea 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -15,12 +15,12 @@
package stack_test
import (
+ "io"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
@@ -39,14 +39,18 @@ const (
// use it.
type fakeTransportEndpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
proto *fakeTransportProtocol
peerAddr tcpip.Address
- route stack.Route
+ route *stack.Route
uniqueID uint64
// acceptQueue is non-nil iff bound.
- acceptQueue []fakeTransportEndpoint
+ acceptQueue []*fakeTransportEndpoint
+
+ // ops is used to set and get socket options.
+ ops tcpip.SocketOptions
}
func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo {
@@ -59,8 +63,14 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats {
func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
+func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
+ return &f.ops
+}
+
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep.ops.InitHandler(ep)
+ return ep
}
func (f *fakeTransportEndpoint) Abort() {
@@ -68,6 +78,7 @@ func (f *fakeTransportEndpoint) Abort() {
}
func (f *fakeTransportEndpoint) Close() {
+ // TODO(gvisor.dev/issue/5153): Consider retaining the route.
f.route.Release()
}
@@ -75,8 +86,8 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask
return mask
}
-func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- return buffer.View{}, tcpip.ControlMessages{}, nil
+func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+ return tcpip.ReadResult{}, nil
}
func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
@@ -100,30 +111,16 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return int64(len(v)), nil, nil
}
-func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
-}
-
// SetSockOpt sets a socket option. Currently not supported.
func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
-// SetSockOptBool sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
-}
-
// SetSockOptInt sets a socket option. Currently not supported.
func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -147,16 +144,16 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return tcpip.ErrNoRoute
}
- defer r.Release()
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
+ r.Release()
return err
}
- f.route = r.Clone()
+ f.route = r
return nil
}
@@ -186,7 +183,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
}
a := f.acceptQueue[0]
f.acceptQueue = f.acceptQueue[1:]
- return &a, nil, nil
+ return a, nil, nil
}
func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
@@ -201,7 +198,7 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
); err != nil {
return err
}
- f.acceptQueue = []fakeTransportEndpoint{}
+ f.acceptQueue = []*fakeTransportEndpoint{}
return nil
}
@@ -227,7 +224,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
}
route.ResolveWith(pkt.SourceLinkAddress())
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
+ ep := &fakeTransportEndpoint{
TransportEndpointInfo: stack.TransportEndpointInfo{
ID: f.ID,
NetProto: f.NetProto,
@@ -235,10 +232,12 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
proto: f.proto,
peerAddr: route.RemoteAddress,
route: route,
- })
+ }
+ ep.ops.InitHandler(ep)
+ f.acceptQueue = append(f.acceptQueue, ep)
}
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.ControlType, uint32, *stack.PacketBuffer) {
// Increment the number of received control packets.
f.proto.controlCount++
}
@@ -553,87 +552,3 @@ func TestTransportOptions(t *testing.T) {
t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true")
}
}
-
-func TestTransportForwarding(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
- })
- s.SetForwarding(fakeNetNumber, true)
-
- // TODO(b/123449044): Change this to a channel NIC.
- ep1 := loopback.New()
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatalf("CreateNIC #1 failed: %v", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress #1 failed: %v", err)
- }
-
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatalf("CreateNIC #2 failed: %v", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress #2 failed: %v", err)
- }
-
- // Route all packets to address 3 to NIC 2 and all packets to address
- // 1 to NIC 1.
- {
- subnet0, err := tcpip.NewSubnet("\x03", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet0, Gateway: "\x00", NIC: 2},
- {Destination: subnet1, Gateway: "\x00", NIC: 1},
- })
- }
-
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Send a packet to address 1 from address 3.
- req := buffer.NewView(30)
- req[0] = 1
- req[1] = 3
- req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: req.ToVectorisedView(),
- }))
-
- aep, _, err := ep.Accept(nil)
- if err != nil || aep == nil {
- t.Fatalf("Accept failed: %v, %v", aep, err)
- }
-
- resp := buffer.NewView(30)
- if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- p, ok := ep2.Read()
- if !ok {
- t.Fatal("Response packet not forwarded")
- }
-
- nh := stack.PayloadSince(p.Pkt.NetworkHeader())
- if dst := nh[0]; dst != 3 {
- t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
- }
- if src := nh[1]; src != 1 {
- t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
- }
-}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 3ab2b7654..f798056c0 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -31,6 +31,7 @@ package tcpip
import (
"errors"
"fmt"
+ "io"
"math/bits"
"reflect"
"strconv"
@@ -39,7 +40,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -49,8 +49,9 @@ const ipv4AddressSize = 4
// Error represents an error in the netstack error space. Using a special type
// ensures that errors outside of this space are not accidentally introduced.
//
-// Note: to support save / restore, it is important that all tcpip errors have
-// distinct error messages.
+// All errors must have unique msg strings.
+//
+// +stateify savable
type Error struct {
msg string
@@ -112,6 +113,7 @@ var (
ErrNotPermitted = &Error{msg: "operation not permitted"}
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
ErrMalformedHeader = &Error{msg: "header is malformed"}
+ ErrBadBuffer = &Error{msg: "bad buffer"}
)
var messageToError map[string]*Error
@@ -161,6 +163,7 @@ func StringToError(s string) *Error {
ErrNotPermitted,
ErrAddressFamilyNotSupported,
ErrMalformedHeader,
+ ErrBadBuffer,
}
messageToError = make(map[string]*Error)
@@ -247,6 +250,54 @@ func (a Address) WithPrefix() AddressWithPrefix {
}
}
+// Unspecified returns true if the address is unspecified.
+func (a Address) Unspecified() bool {
+ for _, b := range a {
+ if b != 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// MatchingPrefix returns the matching prefix length in bits.
+//
+// Panics if b and a have different lengths.
+func (a Address) MatchingPrefix(b Address) uint8 {
+ const bitsInAByte = 8
+
+ if len(a) != len(b) {
+ panic(fmt.Sprintf("addresses %s and %s do not have the same length", a, b))
+ }
+
+ var prefix uint8
+ for i := range a {
+ aByte := a[i]
+ bByte := b[i]
+
+ if aByte == bByte {
+ prefix += bitsInAByte
+ continue
+ }
+
+ // Count the remaining matching bits in the byte from MSbit to LSBbit.
+ mask := uint8(1) << (bitsInAByte - 1)
+ for {
+ if aByte&mask == bByte&mask {
+ prefix++
+ mask >>= 1
+ continue
+ }
+
+ break
+ }
+
+ break
+ }
+
+ return prefix
+}
+
// AddressMask is a bitmask for an address.
type AddressMask string
@@ -447,6 +498,21 @@ func (s SlicePayload) Payload(size int) ([]byte, *Error) {
return s[:size], nil
}
+var _ io.Writer = (*SliceWriter)(nil)
+
+// SliceWriter implements io.Writer for slices.
+type SliceWriter []byte
+
+// Write implements io.Writer.Write.
+func (s *SliceWriter) Write(b []byte) (int, error) {
+ n := copy(*s, b)
+ *s = (*s)[n:]
+ if n < len(b) {
+ return n, io.ErrShortWrite
+ }
+ return n, nil
+}
+
// A ControlMessages contains socket control messages for IP sockets.
//
// +stateify savable
@@ -481,6 +547,17 @@ type ControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo IPPacketInfo
+
+ // HasOriginalDestinationAddress indicates whether OriginalDstAddress is
+ // set.
+ HasOriginalDstAddress bool
+
+ // OriginalDestinationAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress FullAddress
+
+ // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE).
+ SockErr *SockError
}
// PacketOwner is used to get UID and GID of the packet.
@@ -492,6 +569,40 @@ type PacketOwner interface {
GID() uint32
}
+// ReadOptions contains options for Endpoint.Read.
+type ReadOptions struct {
+ // Peek indicates whether this read is a peek.
+ Peek bool
+
+ // NeedRemoteAddr indicates whether to return the remote address, if
+ // supported.
+ NeedRemoteAddr bool
+
+ // NeedLinkPacketInfo indicates whether to return the link-layer information,
+ // if supported.
+ NeedLinkPacketInfo bool
+}
+
+// ReadResult represents result for a successful Endpoint.Read.
+type ReadResult struct {
+ // Count is the number of bytes received and written to the buffer.
+ Count int
+
+ // Total is the number of bytes of the received packet. This can be used to
+ // determine whether the read is truncated.
+ Total int
+
+ // ControlMessages is the control messages received.
+ ControlMessages ControlMessages
+
+ // RemoteAddr is the remote address if ReadOptions.NeedAddr is true.
+ RemoteAddr FullAddress
+
+ // LinkPacketInfo is the link-layer information of the received packet if
+ // ReadOptions.NeedLinkPacketInfo is true.
+ LinkPacketInfo LinkPacketInfo
+}
+
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
// that exposes functionality like read, write, connect, etc. to users of the
// networking stack.
@@ -506,11 +617,15 @@ type Endpoint interface {
// Abort is best effort; implementing Abort with Close is acceptable.
Abort()
- // Read reads data from the endpoint and optionally returns the sender.
+ // Read reads data from the endpoint and optionally writes to dst.
+ //
+ // This method does not block if there is no data pending; in this case,
+ // ErrWouldBlock is returned.
//
- // This method does not block if there is no data pending. It will also
- // either return an error or data, never both.
- Read(*FullAddress) (buffer.View, ControlMessages, *Error)
+ // If non-zero number of bytes are successfully read and written to dst, err
+ // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer
+ // should be returned.
+ Read(dst io.Writer, count int, opts ReadOptions) (res ReadResult, err *Error)
// Write writes data to the endpoint's peer. This method does not block if
// the data cannot be written.
@@ -532,11 +647,6 @@ type Endpoint interface {
// not). The channel is only non-nil in this case.
Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error)
- // Peek reads data without consuming it from the endpoint.
- //
- // This method does not block if there is no data pending.
- Peek([][]byte) (int64, ControlMessages, *Error)
-
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
//
@@ -593,10 +703,6 @@ type Endpoint interface {
// SetSockOpt sets a socket option.
SetSockOpt(opt SettableSocketOption) *Error
- // SetSockOptBool sets a socket option, for simple cases where a value
- // has the bool type.
- SetSockOptBool(opt SockOptBool, v bool) *Error
-
// SetSockOptInt sets a socket option, for simple cases where a value
// has the int type.
SetSockOptInt(opt SockOptInt, v int) *Error
@@ -604,10 +710,6 @@ type Endpoint interface {
// GetSockOpt gets a socket option.
GetSockOpt(opt GettableSocketOption) *Error
- // GetSockOptBool gets a socket option for simple cases where a return
- // value has the bool type.
- GetSockOptBool(SockOptBool) (bool, *Error)
-
// GetSockOptInt gets a socket option for simple cases where a return
// value has the int type.
GetSockOptInt(SockOptInt) (int, *Error)
@@ -634,6 +736,10 @@ type Endpoint interface {
// LastError clears and returns the last error reported by the endpoint.
LastError() *Error
+
+ // SocketOptions returns the structure which contains all the socket
+ // level options.
+ SocketOptions() *SocketOptions
}
// LinkPacketInfo holds Link layer information for a received packet.
@@ -647,17 +753,6 @@ type LinkPacketInfo struct {
PktType PacketType
}
-// PacketEndpoint are additional methods that are only implemented by Packet
-// endpoints.
-type PacketEndpoint interface {
- // ReadPacket reads a datagram/packet from the endpoint and optionally
- // returns the sender and additional LinkPacketInfo.
- //
- // This method does not block if there is no data pending. It will also
- // either return an error or data, never both.
- ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error)
-}
-
// EndpointInfo is the interface implemented by each endpoint info struct.
type EndpointInfo interface {
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
@@ -690,84 +785,6 @@ type WriteOptions struct {
Atomic bool
}
-// SockOptBool represents socket options which values have the bool type.
-type SockOptBool int
-
-const (
- // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether datagram sockets are allowed to send packets to a broadcast
- // address.
- BroadcastOption SockOptBool = iota
-
- // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if
- // data should be held until segments are full by the TCP transport
- // protocol.
- CorkOption
-
- // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if
- // data should be sent out immediately by the transport protocol. For
- // TCP, it determines if the Nagle algorithm is on or off.
- DelayOption
-
- // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to
- // specify whether TCP keepalive is enabled for this socket.
- KeepaliveEnabledOption
-
- // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to
- // specify whether multicast packets sent over a non-loopback interface
- // will be looped back.
- MulticastLoopOption
-
- // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether UDP checksum is disabled for this socket.
- NoChecksumOption
-
- // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether SCM_CREDENTIALS socket control messages are enabled.
- //
- // Only supported on Unix sockets.
- PasscredOption
-
- // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool.
- QuickAckOption
-
- // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to
- // specify if the IPV6_TCLASS ancillary message is passed with incoming
- // packets.
- ReceiveTClassOption
-
- // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify
- // if the TOS ancillary message is passed with incoming packets.
- ReceiveTOSOption
-
- // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to
- // specify if more inforamtion is provided with incoming packets such as
- // interface index and address.
- ReceiveIPPacketInfoOption
-
- // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to
- // specify whether Bind() should allow reuse of local address.
- ReuseAddressOption
-
- // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit
- // multiple sockets to be bound to an identical socket address.
- ReusePortOption
-
- // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify
- // whether an IPv6 socket is to be restricted to sending and receiving
- // IPv6 packets only.
- V6OnlyOption
-
- // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw
- // endpoint that all packets being written have an IP header and the
- // endpoint should not attach an IP header.
- IPHdrIncludedOption
-
- // AcceptConnOption is used by GetSockOptBool to indicate if the
- // socket is a listening socket.
- AcceptConnOption
-)
-
// SockOptInt represents socket options which values have the int type.
type SockOptInt int
@@ -977,14 +994,6 @@ type SettableSocketOption interface {
isSettableSocketOption()
}
-// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
-// should bind only on a specific NIC.
-type BindToDeviceOption NICID
-
-func (*BindToDeviceOption) isGettableSocketOption() {}
-
-func (*BindToDeviceOption) isSettableSocketOption() {}
-
// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
//
// TODO(b/64800844): Add and populate stat fields.
@@ -1159,14 +1168,6 @@ type RemoveMembershipOption MembershipOption
func (*RemoveMembershipOption) isSettableSocketOption() {}
-// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether
-// TCP out-of-band data is delivered along with the normal in-band data.
-type OutOfBandInlineOption int
-
-func (*OutOfBandInlineOption) isGettableSocketOption() {}
-
-func (*OutOfBandInlineOption) isSettableSocketOption() {}
-
// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached
// classic BPF filter on a given endpoint.
type SocketDetachFilterOption int
@@ -1216,10 +1217,6 @@ type LingerOption struct {
Timeout time.Duration
}
-func (*LingerOption) isGettableSocketOption() {}
-
-func (*LingerOption) isSettableSocketOption() {}
-
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
@@ -1390,6 +1387,18 @@ type ICMPv6PacketStats struct {
// RedirectMsg is the total number of ICMPv6 redirect message packets
// counted.
RedirectMsg *StatCounter
+
+ // MulticastListenerQuery is the total number of Multicast Listener Query
+ // messages counted.
+ MulticastListenerQuery *StatCounter
+
+ // MulticastListenerReport is the total number of Multicast Listener Report
+ // messages counted.
+ MulticastListenerReport *StatCounter
+
+ // MulticastListenerDone is the total number of Multicast Listener Done
+ // messages counted.
+ MulticastListenerDone *StatCounter
}
// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats.
@@ -1431,6 +1440,10 @@ type ICMPv6SentPacketStats struct {
type ICMPv6ReceivedPacketStats struct {
ICMPv6PacketStats
+ // Unrecognized is the total number of ICMPv6 packets received that the
+ // transport layer does not know how to parse.
+ Unrecognized *StatCounter
+
// Invalid is the total number of ICMPv6 packets received that the
// transport layer could not parse.
Invalid *StatCounter
@@ -1440,33 +1453,102 @@ type ICMPv6ReceivedPacketStats struct {
RouterOnlyPacketsDroppedByHost *StatCounter
}
-// ICMPStats collects ICMP-specific stats (both v4 and v6).
-type ICMPStats struct {
+// ICMPv4Stats collects ICMPv4-specific stats.
+type ICMPv4Stats struct {
// ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type
// and a single count of packets which failed to write to the link
// layer.
- V4PacketsSent ICMPv4SentPacketStats
+ PacketsSent ICMPv4SentPacketStats
// ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4
// packet type and a single count of invalid packets received.
- V4PacketsReceived ICMPv4ReceivedPacketStats
+ PacketsReceived ICMPv4ReceivedPacketStats
+}
+// ICMPv6Stats collects ICMPv6-specific stats.
+type ICMPv6Stats struct {
// ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type
// and a single count of packets which failed to write to the link
// layer.
- V6PacketsSent ICMPv6SentPacketStats
+ PacketsSent ICMPv6SentPacketStats
// ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6
// packet type and a single count of invalid packets received.
- V6PacketsReceived ICMPv6ReceivedPacketStats
+ PacketsReceived ICMPv6ReceivedPacketStats
+}
+
+// ICMPStats collects ICMP-specific stats (both v4 and v6).
+type ICMPStats struct {
+ // V4 contains the ICMPv4-specifics stats.
+ V4 ICMPv4Stats
+
+ // V6 contains the ICMPv4-specifics stats.
+ V6 ICMPv6Stats
+}
+
+// IGMPPacketStats enumerates counts for all IGMP packet types.
+type IGMPPacketStats struct {
+ // MembershipQuery is the total number of Membership Query messages counted.
+ MembershipQuery *StatCounter
+
+ // V1MembershipReport is the total number of Version 1 Membership Report
+ // messages counted.
+ V1MembershipReport *StatCounter
+
+ // V2MembershipReport is the total number of Version 2 Membership Report
+ // messages counted.
+ V2MembershipReport *StatCounter
+
+ // LeaveGroup is the total number of Leave Group messages counted.
+ LeaveGroup *StatCounter
+}
+
+// IGMPSentPacketStats collects outbound IGMP-specific stats.
+type IGMPSentPacketStats struct {
+ IGMPPacketStats
+
+ // Dropped is the total number of IGMP packets dropped.
+ Dropped *StatCounter
+}
+
+// IGMPReceivedPacketStats collects inbound IGMP-specific stats.
+type IGMPReceivedPacketStats struct {
+ IGMPPacketStats
+
+ // Invalid is the total number of IGMP packets received that IGMP could not
+ // parse.
+ Invalid *StatCounter
+
+ // ChecksumErrors is the total number of IGMP packets dropped due to bad
+ // checksums.
+ ChecksumErrors *StatCounter
+
+ // Unrecognized is the total number of unrecognized messages counted, these
+ // are silently ignored for forward-compatibilty.
+ Unrecognized *StatCounter
+}
+
+// IGMPStats colelcts IGMP-specific stats.
+type IGMPStats struct {
+ // IGMPSentPacketStats contains counts of sent packets by IGMP packet type
+ // and a single count of invalid packets received.
+ PacketsSent IGMPSentPacketStats
+
+ // IGMPReceivedPacketStats contains counts of received packets by IGMP packet
+ // type and a single count of invalid packets received.
+ PacketsReceived IGMPReceivedPacketStats
}
// IPStats collects IP-specific stats (both v4 and v6).
type IPStats struct {
// PacketsReceived is the total number of IP packets received from the
- // link layer in nic.DeliverNetworkPacket.
+ // link layer.
PacketsReceived *StatCounter
+ // DisabledPacketsReceived is the total number of IP packets received from the
+ // link layer when the IP layer is disabled.
+ DisabledPacketsReceived *StatCounter
+
// InvalidDestinationAddressesReceived is the total number of IP packets
// received with an unknown or invalid destination address.
InvalidDestinationAddressesReceived *StatCounter
@@ -1662,6 +1744,9 @@ type Stats struct {
// ICMP breaks out ICMP-specific stats (both v4 and v6).
ICMP ICMPStats
+ // IGMP breaks out IGMP-specific stats.
+ IGMP IGMPStats
+
// IP breaks out IP-specific stats (both v4 and v6).
IP IPStats
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index 1c8e2bc34..9bd563c46 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -226,3 +226,87 @@ func TestAddressWithPrefixSubnet(t *testing.T) {
}
}
}
+
+func TestAddressUnspecified(t *testing.T) {
+ tests := []struct {
+ addr Address
+ unspecified bool
+ }{
+ {
+ addr: "",
+ unspecified: true,
+ },
+ {
+ addr: "\x00",
+ unspecified: true,
+ },
+ {
+ addr: "\x01",
+ unspecified: false,
+ },
+ {
+ addr: "\x00\x00",
+ unspecified: true,
+ },
+ {
+ addr: "\x01\x00",
+ unspecified: false,
+ },
+ {
+ addr: "\x00\x01",
+ unspecified: false,
+ },
+ {
+ addr: "\x01\x01",
+ unspecified: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("addr=%s", test.addr), func(t *testing.T) {
+ if got := test.addr.Unspecified(); got != test.unspecified {
+ t.Fatalf("got addr.Unspecified() = %t, want = %t", got, test.unspecified)
+ }
+ })
+ }
+}
+
+func TestAddressMatchingPrefix(t *testing.T) {
+ tests := []struct {
+ addrA Address
+ addrB Address
+ prefix uint8
+ }{
+ {
+ addrA: "\x01\x01",
+ addrB: "\x01\x01",
+ prefix: 16,
+ },
+ {
+ addrA: "\x01\x01",
+ addrB: "\x01\x00",
+ prefix: 15,
+ },
+ {
+ addrA: "\x01\x01",
+ addrB: "\x81\x00",
+ prefix: 0,
+ },
+ {
+ addrA: "\x01\x01",
+ addrB: "\x01\x80",
+ prefix: 8,
+ },
+ {
+ addrA: "\x01\x01",
+ addrB: "\x02\x80",
+ prefix: 6,
+ },
+ }
+
+ for _, test := range tests {
+ if got := test.addrA.MatchingPrefix(test.addrB); got != test.prefix {
+ t.Errorf("got (%s).MatchingPrefix(%s) = %d, want = %d", test.addrA, test.addrB, got, test.prefix)
+ }
+ }
+}
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 9b0f3b675..ca1e88e99 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -15,16 +15,19 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/ethernet",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/nested",
"//pkg/tcpip/link/pipe",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index bf7594268..4c2084d19 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -15,13 +15,16 @@
package integration_test
import (
+ "bytes"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -31,6 +34,33 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+var _ stack.NetworkDispatcher = (*endpointWithDestinationCheck)(nil)
+var _ stack.LinkEndpoint = (*endpointWithDestinationCheck)(nil)
+
+// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner
+// link endpoint and checks the destination link address before delivering
+// network packets to the network dispatcher.
+//
+// See ethernet.Endpoint for more details.
+func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck {
+ var e endpointWithDestinationCheck
+ e.Endpoint.Init(ethernet.New(ep), &e)
+ return &e
+}
+
+// endpointWithDestinationCheck is a link endpoint that checks the destination
+// link address before delivering network packets to the network dispatcher.
+type endpointWithDestinationCheck struct {
+ nested.Endpoint
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.
+func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) {
+ e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt)
+ }
+}
+
func TestForwarding(t *testing.T) {
const (
host1NICID = 1
@@ -209,16 +239,16 @@ func TestForwarding(t *testing.T) {
host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2)
routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4)
- if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
+ if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
}
- if err := routerStack.CreateNIC(routerNICID1, ethernet.New(routerNIC1)); err != nil {
+ if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil {
t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
}
- if err := routerStack.CreateNIC(routerNICID2, ethernet.New(routerNIC2)); err != nil {
+ if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil {
t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
}
- if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil {
+ if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil {
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
@@ -229,19 +259,6 @@ func TestForwarding(t *testing.T) {
t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
}
- if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
- if err := routerStack.AddAddress(routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
- if err := routerStack.AddAddress(routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
- if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
-
if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
}
@@ -268,58 +285,58 @@ func TestForwarding(t *testing.T) {
}
host1Stack.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
NIC: host1NICID,
},
- tcpip.Route{
+ {
Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
NIC: host1NICID,
},
- tcpip.Route{
+ {
Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address,
NIC: host1NICID,
},
- tcpip.Route{
+ {
Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address,
NIC: host1NICID,
},
})
routerStack.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(),
NIC: routerNICID1,
},
- tcpip.Route{
+ {
Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(),
NIC: routerNICID1,
},
- tcpip.Route{
+ {
Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(),
NIC: routerNICID2,
},
- tcpip.Route{
+ {
Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(),
NIC: routerNICID2,
},
})
host2Stack.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
NIC: host2NICID,
},
- tcpip.Route{
+ {
Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
NIC: host2NICID,
},
- tcpip.Route{
+ {
Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address,
NIC: host2NICID,
},
- tcpip.Route{
+ {
Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address,
NIC: host2NICID,
@@ -366,24 +383,33 @@ func TestForwarding(t *testing.T) {
// Wait for the endpoint to be readable.
<-ch
- var addr tcpip.FullAddress
- v, _, err := ep.Read(&addr)
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: true}
+ res, err := ep.Read(&buf, len(data), opts)
if err != nil {
- t.Fatalf("ep.Read(_): %s", err)
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
}
- if diff := cmp.Diff(v, buffer.View(data)); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ RemoteAddr: tcpip.FullAddress{Addr: expectedFrom},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
}
- if addr.Addr != expectedFrom {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, expectedFrom)
+ if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
if t.Failed() {
t.FailNow()
}
- return addr
+ return res.RemoteAddr
}
addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr)
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index fe7c1bb3d..b4bffaec1 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -15,14 +15,14 @@
package integration_test
import (
+ "bytes"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -87,21 +87,21 @@ func TestPing(t *testing.T) {
transProto tcpip.TransportProtocolNumber
netProto tcpip.NetworkProtocolNumber
remoteAddr tcpip.Address
- icmpBuf func(*testing.T) buffer.View
+ icmpBuf func(*testing.T) []byte
}{
{
name: "IPv4 Ping",
transProto: icmp.ProtocolNumber4,
netProto: ipv4.ProtocolNumber,
remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
- icmpBuf: func(t *testing.T) buffer.View {
+ icmpBuf: func(t *testing.T) []byte {
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
hdr.SetType(header.ICMPv4Echo)
if n := copy(hdr.Payload(), data[:]); n != len(data) {
t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
}
- return buffer.View(hdr)
+ return hdr
},
},
{
@@ -109,14 +109,14 @@ func TestPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
- icmpBuf: func(t *testing.T) buffer.View {
+ icmpBuf: func(t *testing.T) []byte {
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
hdr.SetType(header.ICMPv6EchoRequest)
if n := copy(hdr.Payload(), data[:]); n != len(data) {
t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
}
- return buffer.View(hdr)
+ return hdr
},
},
}
@@ -133,20 +133,13 @@ func TestPing(t *testing.T) {
host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2)
- if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
+ if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
}
- if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil {
+ if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil {
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
- if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
- if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
- }
-
if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil {
t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err)
}
@@ -161,21 +154,21 @@ func TestPing(t *testing.T) {
}
host1Stack.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: ipv4Addr1.AddressWithPrefix.Subnet(),
NIC: host1NICID,
},
- tcpip.Route{
+ {
Destination: ipv6Addr1.AddressWithPrefix.Subnet(),
NIC: host1NICID,
},
})
host2Stack.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: ipv4Addr2.AddressWithPrefix.Subnet(),
NIC: host2NICID,
},
- tcpip.Route{
+ {
Destination: ipv6Addr2.AddressWithPrefix.Subnet(),
NIC: host2NICID,
},
@@ -208,16 +201,25 @@ func TestPing(t *testing.T) {
// Wait for the endpoint to be readable.
<-waiterCH
- var addr tcpip.FullAddress
- v, _, err := ep.Read(&addr)
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: true}
+ res, err := ep.Read(&buf, len(icmpBuf), opts)
if err != nil {
- t.Fatalf("ep.Read(_): %s", err)
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err)
}
- if diff := cmp.Diff(v[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: test.remoteAddr},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
}
- if addr.Addr != test.remoteAddr {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.remoteAddr)
+ if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
})
}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 421da1add..cb6169cfc 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -15,17 +15,20 @@
package integration_test
import (
+ "bytes"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -70,8 +73,8 @@ func TestInitialLoopbackAddresses(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPDisp: &ndpDispatcher{},
- AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDispatcher{},
+ AutoGenLinkLocal: true,
OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(nicID tcpip.NICID, nicName string) string {
t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName)
@@ -93,9 +96,10 @@ func TestInitialLoopbackAddresses(t *testing.T) {
}
}
-// TestLoopbackAcceptAllInSubnet tests that a loopback interface considers
-// itself bound to all addresses in the subnet of an assigned address.
-func TestLoopbackAcceptAllInSubnet(t *testing.T) {
+// TestLoopbackAcceptAllInSubnetUDP tests that a loopback interface considers
+// itself bound to all addresses in the subnet of an assigned address and UDP
+// traffic is sent/received correctly.
+func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
const (
nicID = 1
localPort = 80
@@ -107,7 +111,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
Protocol: header.IPv4ProtocolNumber,
AddressWithPrefix: ipv4Addr,
}
- ipv4Bytes := []byte(ipv4Addr.Address)
+ ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address)
ipv4Bytes[len(ipv4Bytes)-1]++
otherIPv4Address := tcpip.Address(ipv4Bytes)
@@ -129,7 +133,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
{
name: "IPv4 bind to wildcard and send to assigned address",
addAddress: ipv4ProtocolAddress,
- dstAddr: ipv4Addr.Address,
+ dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
expectRx: true,
},
{
@@ -148,7 +152,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
name: "IPv4 bind to other subnet-local address and send to assigned address",
addAddress: ipv4ProtocolAddress,
bindAddr: otherIPv4Address,
- dstAddr: ipv4Addr.Address,
+ dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
expectRx: false,
},
{
@@ -161,7 +165,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
{
name: "IPv4 bind to assigned address and send to other subnet-local address",
addAddress: ipv4ProtocolAddress,
- bindAddr: ipv4Addr.Address,
+ bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
dstAddr: otherIPv4Address,
expectRx: false,
},
@@ -194,11 +198,11 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv4EmptySubnet,
NIC: nicID,
},
- tcpip.Route{
+ {
Destination: header.IPv6EmptySubnet,
NIC: nicID,
},
@@ -236,17 +240,28 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want)
}
- if gotPayload, _, err := rep.Read(nil); test.expectRx {
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: true}
+ if res, err := rep.Read(&buf, len(data), opts); test.expectRx {
if err != nil {
- t.Fatalf("reep.Read(nil): %s", err)
+ t.Fatalf("rep.Read(_, %d, %#v): %s", len(data), opts, err)
}
- if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{
+ Addr: test.addAddress.AddressWithPrefix.Address,
+ },
+ }, res,
+ checker.IgnoreCmpPath("ControlMessages", "RemoteAddr.NIC", "RemoteAddr.Port"),
+ ); diff != "" {
+ t.Errorf("rep.Read: unexpected result (-want +got):\n%s", diff)
}
- } else {
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
+ } else if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock)
}
})
}
@@ -276,7 +291,7 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv4EmptySubnet,
NIC: nicID,
},
@@ -312,3 +327,168 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState)
}
}
+
+// TestLoopbackAcceptAllInSubnetTCP tests that a loopback interface considers
+// itself bound to all addresses in the subnet of an assigned address and TCP
+// traffic is sent/received correctly.
+func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
+ const (
+ nicID = 1
+ localPort = 80
+ )
+
+ ipv4ProtocolAddress := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ }
+ ipv4ProtocolAddress.AddressWithPrefix.PrefixLen = 8
+ ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address)
+ ipv4Bytes[len(ipv4Bytes)-1]++
+ otherIPv4Address := tcpip.Address(ipv4Bytes)
+
+ ipv6ProtocolAddress := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ }
+ ipv6Bytes := []byte(ipv6Addr.Address)
+ ipv6Bytes[len(ipv6Bytes)-1]++
+ otherIPv6Address := tcpip.Address(ipv6Bytes)
+
+ tests := []struct {
+ name string
+ addAddress tcpip.ProtocolAddress
+ bindAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectAccept bool
+ }{
+ {
+ name: "IPv4 bind to wildcard and send to assigned address",
+ addAddress: ipv4ProtocolAddress,
+ dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
+ expectAccept: true,
+ },
+ {
+ name: "IPv4 bind to wildcard and send to other subnet-local address",
+ addAddress: ipv4ProtocolAddress,
+ dstAddr: otherIPv4Address,
+ expectAccept: true,
+ },
+ {
+ name: "IPv4 bind to wildcard send to other address",
+ addAddress: ipv4ProtocolAddress,
+ dstAddr: remoteIPv4Addr,
+ expectAccept: false,
+ },
+ {
+ name: "IPv4 bind to other subnet-local address and send to assigned address",
+ addAddress: ipv4ProtocolAddress,
+ bindAddr: otherIPv4Address,
+ dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
+ expectAccept: false,
+ },
+ {
+ name: "IPv4 bind and send to other subnet-local address",
+ addAddress: ipv4ProtocolAddress,
+ bindAddr: otherIPv4Address,
+ dstAddr: otherIPv4Address,
+ expectAccept: true,
+ },
+ {
+ name: "IPv4 bind to assigned address and send to other subnet-local address",
+ addAddress: ipv4ProtocolAddress,
+ bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
+ dstAddr: otherIPv4Address,
+ expectAccept: false,
+ },
+
+ {
+ name: "IPv6 bind and send to assigned address",
+ addAddress: ipv6ProtocolAddress,
+ bindAddr: ipv6Addr.Address,
+ dstAddr: ipv6Addr.Address,
+ expectAccept: true,
+ },
+ {
+ name: "IPv6 bind to wildcard and send to other subnet-local address",
+ addAddress: ipv6ProtocolAddress,
+ dstAddr: otherIPv6Address,
+ expectAccept: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ listeningEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
+ }
+ defer listeningEndpoint.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
+ if err := listeningEndpoint.Bind(bindAddr); err != nil {
+ t.Fatalf("listeningEndpoint.Bind(%#v): %s", bindAddr, err)
+ }
+
+ if err := listeningEndpoint.Listen(1); err != nil {
+ t.Fatalf("listeningEndpoint.Listen(1): %s", err)
+ }
+
+ connectingEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
+ }
+ defer connectingEndpoint.Close()
+
+ connectAddr := tcpip.FullAddress{
+ Addr: test.dstAddr,
+ Port: localPort,
+ }
+ if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err)
+ }
+
+ if !test.expectAccept {
+ if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+ return
+ }
+
+ // Wait for the listening endpoint to be "readable". That is, wait for a
+ // new connection.
+ <-ch
+ var addr tcpip.FullAddress
+ if _, _, err := listeningEndpoint.Accept(&addr); err != nil {
+ t.Fatalf("listeningEndpoint.Accept(nil): %s", err)
+ }
+ if addr.Addr != test.addAddress.AddressWithPrefix.Address {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 1eecd7957..b42375695 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -15,12 +15,14 @@
package integration_test
import (
+ "bytes"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -35,6 +37,9 @@ import (
const (
defaultMTU = 1280
ttl = 255
+
+ remotePort = 5555
+ localPort = 80
)
var (
@@ -96,11 +101,11 @@ func TestPingMulticastBroadcast(t *testing.T) {
pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{}))
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: ttl,
- SrcAddr: remoteIPv6Addr,
- DstAddr: dst,
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -151,21 +156,21 @@ func TestPingMulticastBroadcast(t *testing.T) {
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err)
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
}
ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr}
if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err)
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err)
}
// Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
// node when attempting to send the ICMP Echo Reply.
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv6EmptySubnet,
NIC: nicID,
},
- tcpip.Route{
+ {
Destination: header.IPv4EmptySubnet,
NIC: nicID,
},
@@ -215,163 +220,219 @@ func TestPingMulticastBroadcast(t *testing.T) {
}
+func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
+ payloadLen := header.UDPMinimumSize + len(data)
+ totalLen := header.IPv4MinimumSize + payloadLen
+ hdr := buffer.NewPrependable(totalLen)
+ u := header.UDP(hdr.Prepend(payloadLen))
+ u.Encode(&header.UDPFields{
+ SrcPort: remotePort,
+ DstPort: localPort,
+ Length: uint16(payloadLen),
+ })
+ copy(u.Payload(), data)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
+ sum = header.Checksum(data, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(udp.ProtocolNumber),
+ TTL: ttl,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+}
+
+func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
+ payloadLen := header.UDPMinimumSize + len(data)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
+ u := header.UDP(hdr.Prepend(payloadLen))
+ u.Encode(&header.UDPFields{
+ SrcPort: remotePort,
+ DstPort: localPort,
+ Length: uint16(payloadLen),
+ })
+ copy(u.Payload(), data)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
+ sum = header.Checksum(data, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLen),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: ttl,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+}
+
// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some
// multicast or broadcast address.
func TestIncomingMulticastAndBroadcast(t *testing.T) {
- const (
- nicID = 1
- remotePort = 5555
- localPort = 80
- )
+ const nicID = 1
data := []byte{1, 2, 3, 4}
- rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) {
- payloadLen := header.UDPMinimumSize + len(data)
- totalLen := header.IPv4MinimumSize + payloadLen
- hdr := buffer.NewPrependable(totalLen)
- u := header.UDP(hdr.Prepend(payloadLen))
- u.Encode(&header.UDPFields{
- SrcPort: remotePort,
- DstPort: localPort,
- Length: uint16(payloadLen),
- })
- copy(u.Payload(), data)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen))
- sum = header.Checksum(data, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(udp.ProtocolNumber),
- TTL: ttl,
- SrcAddr: remoteIPv4Addr,
- DstAddr: dst,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) {
- payloadLen := header.UDPMinimumSize + len(data)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
- u := header.UDP(hdr.Prepend(payloadLen))
- u.Encode(&header.UDPFields{
- SrcPort: remotePort,
- DstPort: localPort,
- Length: uint16(payloadLen),
- })
- copy(u.Payload(), data)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen))
- sum = header.Checksum(data, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLen),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: ttl,
- SrcAddr: remoteIPv6Addr,
- DstAddr: dst,
- })
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
tests := []struct {
- name string
- bindAddr tcpip.Address
- dstAddr tcpip.Address
- expectRx bool
+ name string
+ proto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ localAddr tcpip.AddressWithPrefix
+ rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
+ bindAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectRx bool
}{
{
- name: "IPv4 unicast binding to unicast",
- bindAddr: ipv4Addr.Address,
- dstAddr: ipv4Addr.Address,
- expectRx: true,
+ name: "IPv4 unicast binding to unicast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: ipv4Addr.Address,
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
},
{
- name: "IPv4 unicast binding to broadcast",
- bindAddr: header.IPv4Broadcast,
- dstAddr: ipv4Addr.Address,
- expectRx: false,
+ name: "IPv4 unicast binding to broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: ipv4Addr.Address,
+ expectRx: false,
},
{
- name: "IPv4 unicast binding to wildcard",
- dstAddr: ipv4Addr.Address,
- expectRx: true,
+ name: "IPv4 unicast binding to wildcard",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
},
{
- name: "IPv4 directed broadcast binding to subnet broadcast",
- bindAddr: ipv4SubnetBcast,
- dstAddr: ipv4SubnetBcast,
- expectRx: true,
+ name: "IPv4 directed broadcast binding to subnet broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: ipv4SubnetBcast,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
},
{
- name: "IPv4 directed broadcast binding to broadcast",
- bindAddr: header.IPv4Broadcast,
- dstAddr: ipv4SubnetBcast,
- expectRx: false,
+ name: "IPv4 directed broadcast binding to broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: false,
},
{
- name: "IPv4 directed broadcast binding to wildcard",
- dstAddr: ipv4SubnetBcast,
- expectRx: true,
+ name: "IPv4 directed broadcast binding to wildcard",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
},
{
- name: "IPv4 broadcast binding to broadcast",
- bindAddr: header.IPv4Broadcast,
- dstAddr: header.IPv4Broadcast,
- expectRx: true,
+ name: "IPv4 broadcast binding to broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: header.IPv4Broadcast,
+ expectRx: true,
},
{
- name: "IPv4 broadcast binding to subnet broadcast",
- bindAddr: ipv4SubnetBcast,
- dstAddr: header.IPv4Broadcast,
- expectRx: false,
+ name: "IPv4 broadcast binding to subnet broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: ipv4SubnetBcast,
+ dstAddr: header.IPv4Broadcast,
+ expectRx: false,
},
{
- name: "IPv4 broadcast binding to wildcard",
- dstAddr: ipv4SubnetBcast,
- expectRx: true,
+ name: "IPv4 broadcast binding to wildcard",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
},
{
- name: "IPv4 all-systems multicast binding to all-systems multicast",
- bindAddr: header.IPv4AllSystems,
- dstAddr: header.IPv4AllSystems,
- expectRx: true,
+ name: "IPv4 all-systems multicast binding to all-systems multicast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: header.IPv4AllSystems,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: true,
},
{
- name: "IPv4 all-systems multicast binding to wildcard",
- dstAddr: header.IPv4AllSystems,
- expectRx: true,
+ name: "IPv4 all-systems multicast binding to wildcard",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: true,
},
{
- name: "IPv4 all-systems multicast binding to unicast",
- bindAddr: ipv4Addr.Address,
- dstAddr: header.IPv4AllSystems,
- expectRx: false,
+ name: "IPv4 all-systems multicast binding to unicast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: ipv4Addr.Address,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: false,
},
// IPv6 has no notion of a broadcast.
{
- name: "IPv6 unicast binding to wildcard",
- dstAddr: ipv6Addr.Address,
- expectRx: true,
+ name: "IPv6 unicast binding to wildcard",
+ dstAddr: ipv6Addr.Address,
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: remoteIPv6Addr,
+ localAddr: ipv6Addr,
+ rxUDP: rxIPv6UDP,
+ expectRx: true,
},
{
- name: "IPv6 broadcast-like address binding to wildcard",
- dstAddr: ipv6SubnetBcast,
- expectRx: false,
+ name: "IPv6 broadcast-like address binding to wildcard",
+ dstAddr: ipv6SubnetBcast,
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: remoteIPv6Addr,
+ localAddr: ipv6Addr,
+ rxUDP: rxIPv6UDP,
+ expectRx: false,
},
}
@@ -385,52 +446,41 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err)
- }
- ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr}
- if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err)
- }
-
- var netproto tcpip.NetworkProtocolNumber
- var rxUDP func(*channel.Endpoint, tcpip.Address)
- switch l := len(test.dstAddr); l {
- case header.IPv4AddressSize:
- netproto = header.IPv4ProtocolNumber
- rxUDP = rxIPv4UDP
- case header.IPv6AddressSize:
- netproto = header.IPv6ProtocolNumber
- rxUDP = rxIPv6UDP
- default:
- t.Fatalf("got unexpected address length = %d bytes", l)
+ protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
}
var wq waiter.Queue
- ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err)
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
}
defer ep.Close()
bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("ep.Bind(%+v): %s", bindAddr, err)
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
}
- rxUDP(e, test.dstAddr)
- if gotPayload, _, err := ep.Read(nil); test.expectRx {
+ test.rxUDP(e, test.remoteAddr, test.dstAddr, data)
+ var buf bytes.Buffer
+ var opts tcpip.ReadOptions
+ if res, err := ep.Read(&buf, len(data), opts); test.expectRx {
if err != nil {
- t.Fatalf("Read(nil): %s", err)
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
}
- if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
}
- } else {
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
+ } else if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock)
}
})
}
@@ -476,11 +526,11 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
},
}
if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err)
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
// We use the empty subnet instead of just the loopback subnet so we
// also have a route to the IPv4 Broadcast address.
Destination: header.IPv4EmptySubnet,
@@ -510,23 +560,18 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err)
- }
-
- if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
- t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err)
- }
+ ep.SocketOptions().SetReuseAddress(true)
+ ep.SocketOptions().SetBroadcast(true)
bindAddr := tcpip.FullAddress{Port: localPort}
if bindWildcard {
if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
}
} else {
bindAddr.Addr = test.broadcastAddr
if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
}
}
@@ -552,9 +597,19 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
// Wait for the endpoint to become readable.
<-rep.ch
- if gotPayload, _, err := rep.ep.Read(nil); err != nil {
- t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err)
- } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ var buf bytes.Buffer
+ result, err := rep.ep.Read(&buf, len(data), tcpip.ReadOptions{})
+ if err != nil {
+ t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err)
+ continue
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff)
+ }
+ if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" {
t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
}
}
@@ -562,3 +617,153 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
})
}
}
+
+func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
+ const (
+ nicID = 1
+ )
+
+ data := []byte{1, 2, 3, 4}
+
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ localAddr tcpip.AddressWithPrefix
+ rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
+ multicastAddr tcpip.Address
+ }{
+ {
+ name: "IPv4 unicast binding to unicast",
+ multicastAddr: "\xe0\x01\x02\x03",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ },
+ {
+ name: "IPv6 broadcast-like address binding to wildcard",
+ multicastAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04",
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: remoteIPv6Addr,
+ localAddr: ipv6Addr,
+ rxUDP: rxIPv6UDP,
+ },
+ }
+
+ subTests := []struct {
+ name string
+ specifyNICID bool
+ specifyNICAddr bool
+ }{
+ {
+ name: "Specify NIC ID and NIC address",
+ specifyNICID: true,
+ specifyNICAddr: true,
+ },
+ {
+ name: "Don't specify NIC ID or NIC address",
+ specifyNICID: false,
+ specifyNICAddr: false,
+ },
+ {
+ name: "Specify NIC ID but don't specify NIC address",
+ specifyNICID: true,
+ specifyNICAddr: false,
+ },
+ {
+ name: "Don't specify NIC ID but specify NIC address",
+ specifyNICID: false,
+ specifyNICAddr: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ }
+
+ // Set the route table so that UDP can find a NIC that is
+ // routable to the multicast address when the NIC isn't specified.
+ if !subTest.specifyNICID && !subTest.specifyNICAddr {
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+ }
+
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Port: localPort}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+ }
+
+ memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr}
+ if subTest.specifyNICID {
+ memOpt.NIC = nicID
+ }
+ if subTest.specifyNICAddr {
+ memOpt.InterfaceAddr = test.localAddr.Address
+ }
+
+ // We should receive UDP packets to the group once we join the
+ // multicast group.
+ addOpt := tcpip.AddMembershipOption(memOpt)
+ if err := ep.SetSockOpt(&addOpt); err != nil {
+ t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err)
+ }
+ test.rxUDP(e, test.remoteAddr, test.multicastAddr, data)
+ var buf bytes.Buffer
+ result, err := ep.Read(&buf, len(data), tcpip.ReadOptions{})
+ if err != nil {
+ t.Fatalf("ep.Read: %s", err)
+ } else {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ // We should not receive UDP packets to the group once we leave
+ // the multicast group.
+ removeOpt := tcpip.RemoveMembershipOption(memOpt)
+ if err := ep.SetSockOpt(&removeOpt); err != nil {
+ t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err)
+ }
+ if _, err := ep.Read(&buf, 1, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index 02fc47015..52cf89b54 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -15,11 +15,14 @@
package integration_test
import (
+ "bytes"
+ "math"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -203,16 +206,25 @@ func TestLocalPing(t *testing.T) {
// Wait for the endpoint to become readable.
<-ch
- var addr tcpip.FullAddress
- v, _, err := ep.Read(&addr)
+ var buf bytes.Buffer
+ opts := tcpip.ReadOptions{NeedRemoteAddr: true}
+ res, err := ep.Read(&buf, math.MaxUint16, opts)
if err != nil {
- t.Fatalf("ep.Read(_): %s", err)
+ t.Fatalf("ep.Read(_, %d, %#v): %s", math.MaxUint16, opts, err)
}
- if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: test.localAddr},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
}
- if addr.Addr != test.localAddr {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr)
+ if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
test.checkLinkEndpoint(t, e)
@@ -338,14 +350,27 @@ func TestLocalUDP(t *testing.T) {
<-serverCH
var clientAddr tcpip.FullAddress
- if v, _, err := server.Read(&clientAddr); err != nil {
+ var readBuf bytes.Buffer
+ if read, err := server.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
t.Fatalf("server.Read(_): %s", err)
} else {
- if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" {
- t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
+ clientAddr = read.RemoteAddr
+
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: readBuf.Len(),
+ Total: readBuf.Len(),
+ RemoteAddr: tcpip.FullAddress{
+ Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
+ },
+ }, read, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff)
}
- if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address {
- t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address)
+ if diff := cmp.Diff(buffer.View(clientPayload), buffer.View(readBuf.Bytes())); diff != "" {
+ t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
}
if t.Failed() {
t.FailNow()
@@ -367,15 +392,23 @@ func TestLocalUDP(t *testing.T) {
// Wait for the client endpoint to become readable.
<-clientCH
- var gotServerAddr tcpip.FullAddress
- if v, _, err := client.Read(&gotServerAddr); err != nil {
+ readBuf.Reset()
+ if read, err := client.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
t.Fatalf("client.Read(_): %s", err)
} else {
- if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" {
- t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: readBuf.Len(),
+ Total: readBuf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr},
+ }, read, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff)
}
- if gotServerAddr.Addr != serverAddr.Addr {
- t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr)
+ if diff := cmp.Diff(buffer.View(serverPayload), buffer.View(readBuf.Bytes())); diff != "" {
+ t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
}
if t.Failed() {
t.FailNow()
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 763cd8f84..c32fe5c4f 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,6 +15,8 @@
package icmp
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -49,6 +51,7 @@ const (
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
@@ -71,18 +74,19 @@ type endpoint struct {
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
state endpointState
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return &endpoint{
+ ep := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
@@ -93,7 +97,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
sndBufSize: 32 * 1024,
state: stateInitial,
uniqueID: s.UniqueID(),
- }, nil
+ }
+ ep.ops.InitHandler(ep)
+ return ep, nil
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -126,7 +132,10 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
// Update the state.
e.state = stateClosed
@@ -139,13 +148,13 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -155,20 +164,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+ }
e.rcvMu.Unlock()
- if addr != nil {
- *addr = p.senderAddress
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ },
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
}
- return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -264,26 +287,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
}
- var route *stack.Route
- if to == nil {
- route = &e.route
-
- if route.IsResolutionRequired() {
- // Promote lock to exclusive if using a shared route,
- // given that it may need to change in Route.Resolve()
- // call below.
- e.mu.RUnlock()
- defer e.mu.RLock()
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // Recheck state after lock was re-acquired.
- if e.state != stateConnected {
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
- }
- } else {
+ route := e.route
+ if to != nil {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
@@ -307,7 +312,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
defer r.Release()
- route = &r
+ route = r
}
if route.IsResolutionRequired() {
@@ -339,27 +344,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
-}
-
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.SocketDetachFilterOption:
- return nil
-
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
- }
- return nil
-}
-
-// SetSockOptBool sets a socket option. Currently not supported.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return nil
}
@@ -375,17 +361,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
- return false, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
@@ -423,16 +398,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.mu.Lock()
- *o = e.linger
- e.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
+ return tcpip.ErrUnknownProtocolOption
}
func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
@@ -548,7 +514,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: r.LocalAddress,
@@ -563,11 +528,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, err = e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
e.ID = id
- e.route = r.Clone()
+ e.route = r
e.RegisterNICID = nicID
e.state = stateConnected
@@ -823,7 +789,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
}
// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
@@ -853,3 +819,8 @@ func (*endpoint) Wait() {}
func (*endpoint) LastError() *tcpip.Error {
return nil
}
+
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
+func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 31831a6d8..3ab060751 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -26,6 +26,7 @@ package packet
import (
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -60,6 +61,8 @@ type packet struct {
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
+
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
@@ -83,12 +86,13 @@ type endpoint struct {
stats tcpip.TransportEndpointStats `state:"nosave"`
bound bool
boundNIC tcpip.NICID
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// NewEndpoint returns a new packet endpoint.
@@ -104,6 +108,7 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
}
+ ep.ops.InitHandler(ep)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -156,8 +161,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// Read implements tcpip.PacketEndpoint.ReadPacket.
-func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -169,29 +174,37 @@ func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketIn
err = tcpip.ErrClosedForReceive
}
ep.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
packet := ep.rcvList.Front()
- ep.rcvList.Remove(packet)
- ep.rcvBufSize -= packet.data.Size()
+ if !opts.Peek {
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+ }
ep.rcvMu.Unlock()
- if addr != nil {
- *addr = packet.senderAddr
+ res := tcpip.ReadResult{
+ Total: packet.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: packet.timestampNS,
+ },
}
-
- if info != nil {
- *info = packet.packetInfo
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = packet.senderAddr
+ }
+ if opts.NeedLinkPacketInfo {
+ res.LinkPacketInfo = packet.packetInfo
}
- return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
-}
-
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- return ep.ReadPacket(addr, nil)
+ n, err := packet.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
@@ -199,11 +212,6 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha
return 0, nil, tcpip.ErrInvalidOptionValue
}
-// Peek implements tcpip.Endpoint.Peek.
-func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
// disconnected, and this function always returns tpcip.ErrNotSupported.
func (*endpoint) Disconnect() *tcpip.Error {
@@ -300,26 +308,15 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
+ switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- ep.mu.Lock()
- ep.linger = *v
- ep.mu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
}
-// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
-func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
-}
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
@@ -373,28 +370,16 @@ func (ep *endpoint) LastError() *tcpip.Error {
return err
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- ep.mu.Lock()
- *o = ep.linger
- ep.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrNotSupported
- }
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (ep *endpoint) UpdateLastError(err *tcpip.Error) {
+ ep.lastErrorMu.Lock()
+ ep.lastError = err
+ ep.lastErrorMu.Unlock()
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.AcceptConnOption:
- return false, nil
- default:
- return false, tcpip.ErrNotSupported
- }
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ return tcpip.ErrNotSupported
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -548,4 +533,10 @@ func (ep *endpoint) Stats() tcpip.EndpointStats {
return &ep.stats
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
+
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
+func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &ep.ops
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 7b6a87ba9..dd260535f 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -27,6 +27,7 @@ package raw
import (
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -58,12 +59,13 @@ type rawPacket struct {
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
+
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
- hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
@@ -82,13 +84,14 @@ type endpoint struct {
bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// NewEndpoint returns a raw endpoint for the given protocols.
@@ -111,8 +114,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
associated: associated,
- hdrIncluded: !associated,
}
+ e.ops.InitHandler(e)
+ e.ops.SetHeaderIncluded(!associated)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -167,9 +171,11 @@ func (e *endpoint) Close() {
e.rcvList.Remove(e.rcvList.Front())
}
- if e.connected {
+ e.connected = false
+
+ if e.route != nil {
e.route.Release()
- e.connected = false
+ e.route = nil
}
e.closed = true
@@ -185,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -197,20 +203,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
pkt := e.rcvList.Front()
- e.rcvList.Remove(pkt)
- e.rcvBufSize -= pkt.data.Size()
+ if !opts.Peek {
+ e.rcvList.Remove(pkt)
+ e.rcvBufSize -= pkt.data.Size()
+ }
e.rcvMu.Unlock()
- if addr != nil {
- *addr = pkt.senderAddr
+ res := tcpip.ReadResult{
+ Total: pkt.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: pkt.timestampNS,
+ },
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = pkt.senderAddr
}
- return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil
+ n, err := pkt.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// Write implements tcpip.Endpoint.Write.
@@ -220,6 +240,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrInvalidOptionValue
}
+ if opts.To != nil {
+ // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint.
+ if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+ }
+
n, ch, err := e.write(p, opts)
switch err {
case nil:
@@ -249,24 +276,22 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
e.mu.RLock()
+ defer e.mu.RUnlock()
if e.closed {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidEndpointState
}
payloadBytes, err := p.FullPayload()
if err != nil {
- e.mu.RUnlock()
return 0, nil, err
}
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if e.hdrIncluded {
+ if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidOptionValue
}
dstAddr := ip.DestinationAddress()
@@ -288,39 +313,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If the user doesn't specify a destination, they should have
// connected to another address.
if !e.connected {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrDestinationRequired
}
- if e.route.IsResolutionRequired() {
- savedRoute := &e.route
- // Promote lock to exclusive if using a shared route,
- // given that it may need to change in finishWrite.
- e.mu.RUnlock()
- e.mu.Lock()
-
- // Make sure that the route didn't change during the
- // time we didn't hold the lock.
- if !e.connected || savedRoute != &e.route {
- e.mu.Unlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
- n, ch, err := e.finishWrite(payloadBytes, savedRoute)
- e.mu.Unlock()
- return n, ch, err
- }
-
- n, ch, err := e.finishWrite(payloadBytes, &e.route)
- e.mu.RUnlock()
- return n, ch, err
+ return e.finishWrite(payloadBytes, e.route)
}
// The caller provided a destination. Reject destination address if it
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
if e.bound && nic != 0 && nic != e.BindNICID {
- e.mu.RUnlock()
return 0, nil, tcpip.ErrNoRoute
}
@@ -328,13 +330,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
if err != nil {
- e.mu.RUnlock()
return 0, nil, err
}
- n, ch, err := e.finishWrite(payloadBytes, &route)
+ n, ch, err := e.finishWrite(payloadBytes, route)
route.Release()
- e.mu.RUnlock()
return n, ch, err
}
@@ -353,7 +353,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- if e.hdrIncluded {
+ if e.ops.GetHeaderIncluded() {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(payloadBytes).ToVectorisedView(),
})
@@ -378,11 +378,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
return int64(len(payloadBytes)), nil, nil
}
-// Peek implements tcpip.Endpoint.Peek.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect.
func (*endpoint) Disconnect() *tcpip.Error {
return tcpip.ErrNotSupported
@@ -390,6 +385,11 @@ func (*endpoint) Disconnect() *tcpip.Error {
// Connect implements tcpip.Endpoint.Connect.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint.
+ if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
e.mu.Lock()
defer e.mu.Unlock()
@@ -418,11 +418,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer route.Release()
if e.associated {
// Re-register the endpoint with the appropriate NIC.
if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
+ route.Release()
return err
}
e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
@@ -430,7 +430,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Save the route we've connected via.
- e.route = route.Clone()
+ e.route = route
e.connected = true
return nil
@@ -513,33 +513,15 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
+ switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
}
-// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.IPHdrIncludedOption:
- e.mu.Lock()
- e.hdrIncluded = v
- e.mu.Unlock()
- return nil
- }
- return tcpip.ErrUnknownProtocolOption
-}
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
@@ -586,33 +568,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.mu.Lock()
- *o = e.linger
- e.mu.Unlock()
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
- return false, nil
-
- case tcpip.IPHdrIncludedOption:
- e.mu.Lock()
- v := e.hdrIncluded
- e.mu.Unlock()
- return v, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
+ return tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -647,6 +603,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ e.mu.RLock()
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full or if this is an unassociated
@@ -659,6 +616,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// sockets.
if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
return
@@ -666,6 +624,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if e.rcvBufSize >= e.rcvBufSizeMax {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
return
@@ -677,11 +636,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// If bound to a NIC, only accept data for that NIC.
if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
// If bound to an address, only accept data for that address.
if e.BindAddr != "" && e.BindAddr != remoteAddr {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
}
@@ -690,6 +651,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// connected to.
if e.connected && e.route.RemoteAddress != remoteAddr {
e.rcvMu.Unlock()
+ e.mu.RUnlock()
return
}
@@ -724,6 +686,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
e.rcvMu.Unlock()
+ e.mu.RUnlock()
e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
if wasEmpty {
@@ -753,6 +716,12 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements stack.TransportEndpoint.Wait.
func (*endpoint) Wait() {}
+// LastError implements tcpip.Endpoint.LastError.
func (*endpoint) LastError() *tcpip.Error {
return nil
}
+
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
+func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 7d97cbdc7..4a7e1c039 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -73,7 +73,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
// If the endpoint is connected, re-connect.
if e.connected {
var err *tcpip.Error
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false)
+ // TODO(gvisor.dev/issue/4906): Properly restore the route with the right
+ // remote address. We used to pass e.remote.RemoteAddress which was
+ // effectively the empty address but since moving e.route to hold a pointer
+ // to a route instead of the route by value, we pass the empty address
+ // directly. Obviously this was always wrong since we should provide the
+ // remote address we were connected to, to properly restore the route.
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 518449602..7e81203ba 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "more_shards")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -45,7 +45,9 @@ go_library(
"rcv.go",
"rcv_state.go",
"reno.go",
+ "reno_recovery.go",
"sack.go",
+ "sack_recovery.go",
"sack_scoreboard.go",
"segment.go",
"segment_heap.go",
@@ -91,7 +93,7 @@ go_test(
"tcp_test.go",
"tcp_timestamp_test.go",
],
- shard_count = 10,
+ shard_count = more_shards,
deps = [
":tcp",
"//pkg/rand",
@@ -110,6 +112,7 @@ go_test(
"//pkg/tcpip/transport/tcp/testing/context",
"//pkg/test/testutil",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 47982ca41..2d96a65bd 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -213,7 +213,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
route.ResolveWith(s.remoteLinkAddr)
n := newEndpoint(l.stack, netProto, queue)
- n.v6only = l.v6Only
+ n.ops.SetV6Only(l.v6Only)
n.ID = s.id
n.boundNICID = s.nicID
n.route = route
@@ -235,11 +235,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
return n, nil
}
-// createEndpointAndPerformHandshake creates a new endpoint in connected state
-// and then performs the TCP 3-way handshake.
+// startHandshake creates a new endpoint in connecting state and then sends
+// the SYN-ACK for the TCP 3-way handshake. It returns the state of the
+// handshake in progress, which includes the new endpoint in the SYN-RCVD
+// state.
//
-// The new endpoint is returned with e.mu held.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+// On success, a handshake h is returned with h.ep.mu held.
+//
+// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
+func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
@@ -257,10 +261,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// listenEP is nil when listenContext is used by tcp.Forwarder.
deferAccept := time.Duration(0)
if l.listenEP != nil {
- l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
- l.listenEP.mu.Unlock()
// Ensure we release any registrations done by the newly
// created endpoint.
ep.mu.Unlock()
@@ -278,16 +280,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
ep.mu.Unlock()
ep.Close()
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- l.listenEP.mu.Unlock()
- }
+ l.removePendingEndpoint(ep)
return nil, tcpip.ErrConnectionAborted
}
deferAccept = l.listenEP.deferAccept
- l.listenEP.mu.Unlock()
}
// Register new endpoint so that packets are routed to it.
@@ -306,28 +304,33 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
ep.isRegistered = true
- // Perform the 3-way handshake.
- h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
- if err := h.execute(); err != nil {
- ep.mu.Unlock()
- ep.Close()
- ep.notifyAborted()
-
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- }
-
- ep.drainClosingSegmentQueue()
-
+ // Initialize and start the handshake.
+ h := ep.newPassiveHandshake(isn, irs, opts, deferAccept)
+ if err := h.start(); err != nil {
+ l.cleanupFailedHandshake(h)
return nil, err
}
- ep.isConnectNotified = true
+ return h, nil
+}
- // Update the receive window scaling. We can't do it before the
- // handshake because it's possible that the peer doesn't support window
- // scaling.
- ep.rcv.rcvWndScale = h.effectiveRcvWndScale()
+// performHandshake performs a TCP 3-way handshake. On success, the new
+// established endpoint is returned with e.mu held.
+//
+// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
+func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+ h, err := l.startHandshake(s, opts, queue, owner)
+ if err != nil {
+ return nil, err
+ }
+ ep := h.ep
+ if err := h.complete(); err != nil {
+ ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ ep.stats.FailedConnectionAttempts.Increment()
+ l.cleanupFailedHandshake(h)
+ return nil, err
+ }
+ l.cleanupCompletedHandshake(h)
return ep, nil
}
@@ -354,6 +357,39 @@ func (l *listenContext) closeAllPendingEndpoints() {
l.pending.Wait()
}
+// Precondition: h.ep.mu must be held.
+func (l *listenContext) cleanupFailedHandshake(h *handshake) {
+ e := h.ep
+ e.mu.Unlock()
+ e.Close()
+ e.notifyAborted()
+ if l.listenEP != nil {
+ l.removePendingEndpoint(e)
+ }
+ e.drainClosingSegmentQueue()
+ e.h = nil
+}
+
+// cleanupCompletedHandshake transfers any state from the completed handshake to
+// the new endpoint.
+//
+// Precondition: h.ep.mu must be held.
+func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
+ e := h.ep
+ if l.listenEP != nil {
+ l.removePendingEndpoint(e)
+ }
+ e.isConnectNotified = true
+
+ // Update the receive window scaling. We can't do it before the
+ // handshake because it's possible that the peer doesn't support window
+ // scaling.
+ e.rcv.rcvWndScale = e.h.effectiveRcvWndScale()
+
+ // Clean up handshake state stored in the endpoint so that it can be GCed.
+ e.h = nil
+}
+
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state (acceptedChan is nil),
// the new endpoint is closed instead.
@@ -433,23 +469,40 @@ func (e *endpoint) notifyAborted() {
//
// A limited number of these goroutines are allowed before TCP starts using SYN
// cookies to accept connections.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
- defer ctx.synRcvdCount.dec()
+//
+// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error {
defer s.decRef()
- n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
+ h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- e.decSynRcvdCount()
- return
+ e.synRcvdCount--
+ return err
}
- ctx.removePendingEndpoint(n)
- e.decSynRcvdCount()
- n.startAcceptedLoop()
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(n)
+ go func() {
+ defer ctx.synRcvdCount.dec()
+ if err := h.complete(); err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ ctx.cleanupFailedHandshake(h)
+ e.mu.Lock()
+ e.synRcvdCount--
+ e.mu.Unlock()
+ return
+ }
+ ctx.cleanupCompletedHandshake(h)
+ e.mu.Lock()
+ e.synRcvdCount--
+ e.mu.Unlock()
+ h.ep.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+ e.deliverAccepted(h.ep)
+ }() // S/R-SAFE: synRcvdCount is the barrier.
+
+ return nil
}
func (e *endpoint) incSynRcvdCount() bool {
@@ -462,12 +515,6 @@ func (e *endpoint) incSynRcvdCount() bool {
return canInc
}
-func (e *endpoint) decSynRcvdCount() {
- e.mu.Lock()
- e.synRcvdCount--
- e.mu.Unlock()
-}
-
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
@@ -477,6 +524,8 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
+//
+// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error {
e.rcvListMu.Lock()
rcvClosed := e.rcvClosed
@@ -500,7 +549,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
// backlog.
if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
s.incRef()
- go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
+ _ = e.handleSynSegment(ctx, s, &opts)
return nil
}
ctx.synRcvdCount.dec()
@@ -550,7 +599,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
ack: s.sequenceNumber + 1,
rcvWnd: ctx.rcvWnd,
}
- if err := e.sendSynTCP(&route, fields, synOpts); err != nil {
+ if err := e.sendSynTCP(route, fields, synOpts); err != nil {
return err
}
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
@@ -703,7 +752,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
- v6Only := e.v6only
+ v6Only := e.ops.GetV6Only()
ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
defer func() {
@@ -712,7 +761,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
// to the endpoint.
e.setEndpointState(StateClose)
- // close any endpoints in SYN-RCVD state.
+ // Close any endpoints in SYN-RCVD state.
ctx.closeAllPendingEndpoints()
// Do cleanup if needed.
@@ -729,7 +778,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
for {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 6e9015be1..a00ef97c6 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -16,6 +16,7 @@ package tcp
import (
"encoding/binary"
+ "math"
"time"
"gvisor.dev/gvisor/pkg/rand"
@@ -102,21 +103,26 @@ type handshake struct {
// been received. This is required to stop retransmitting the
// original SYN-ACK when deferAccept is enabled.
acked bool
+
+ // sendSYNOpts is the cached values for the SYN options to be sent.
+ sendSYNOpts header.TCPSynOptions
}
-func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
- h := handshake{
- ep: ep,
+func (e *endpoint) newHandshake() *handshake {
+ h := &handshake{
+ ep: e,
active: true,
- rcvWnd: rcvWnd,
- rcvWndScale: ep.rcvWndScaleForHandshake(),
+ rcvWnd: seqnum.Size(e.initialReceiveWindow()),
+ rcvWndScale: e.rcvWndScaleForHandshake(),
}
h.resetState()
+ // Store reference to handshake state in endpoint.
+ e.h = h
return h
}
-func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake {
- h := newHandshake(ep, rcvWnd)
+func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake {
+ h := e.newHandshake()
h.resetToSynRcvd(isn, irs, opts, deferAccept)
return h
}
@@ -128,7 +134,7 @@ func FindWndScale(wnd seqnum.Size) int {
return 0
}
- max := seqnum.Size(0xffff)
+ max := seqnum.Size(math.MaxUint16)
s := 0
for wnd > max && s < header.MaxWndScale {
s++
@@ -295,7 +301,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
if ttl == 0 {
ttl = h.ep.route.DefaultTTL()
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: ttl,
tos: h.ep.sendTOS,
@@ -356,7 +362,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -456,7 +462,7 @@ func (h *handshake) processSegments() *tcpip.Error {
func (h *handshake) resolveRoute() *tcpip.Error {
// Set up the wakers.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
resolutionWaker := &sleep.Waker{}
s.AddWaker(resolutionWaker, wakerForResolution)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
@@ -464,24 +470,27 @@ func (h *handshake) resolveRoute() *tcpip.Error {
// Initial action is to resolve route.
index := wakerForResolution
+ attemptedResolution := false
for {
switch index {
case wakerForResolution:
- if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
- if err == tcpip.ErrNoLinkAddress {
- h.ep.stats.SendErrors.NoLinkAddr.Increment()
- } else if err != nil {
+ if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock {
+ if err != nil {
h.ep.stats.SendErrors.NoRoute.Increment()
}
// Either success (err == nil) or failure.
return err
}
+ if attemptedResolution {
+ h.ep.stats.SendErrors.NoLinkAddr.Increment()
+ return tcpip.ErrNoLinkAddress
+ }
+ attemptedResolution = true
// Resolution not completed. Keep trying...
case wakerForNotification:
n := h.ep.fetchNotifications()
if n&notifyClose != 0 {
- h.ep.route.RemoveWaker(resolutionWaker)
return tcpip.ErrAborted
}
if n&notifyDrain != 0 {
@@ -491,7 +500,7 @@ func (h *handshake) resolveRoute() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
- return h.ep.LastError()
+ return h.ep.lastErrorLocked()
}
}
@@ -502,8 +511,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
}
}
-// execute executes the TCP 3-way handshake.
-func (h *handshake) execute() *tcpip.Error {
+// start resolves the route if necessary and sends the first
+// SYN/SYN-ACK.
+func (h *handshake) start() *tcpip.Error {
if h.ep.route.IsResolutionRequired() {
if err := h.resolveRoute(); err != nil {
return err
@@ -511,19 +521,7 @@ func (h *handshake) execute() *tcpip.Error {
}
h.startTime = time.Now()
- // Initialize the resend timer.
- resendWaker := sleep.Waker{}
- timeOut := time.Duration(time.Second)
- rt := time.AfterFunc(timeOut, resendWaker.Assert)
- defer rt.Stop()
-
- // Set up the wakers.
- s := sleep.Sleeper{}
- s.AddWaker(&resendWaker, wakerForResend)
- s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
- s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
- defer s.Done()
-
+ h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
var sackEnabled tcpip.TCPSACKEnabled
if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
// If stack returned an error when checking for SACKEnabled
@@ -531,10 +529,6 @@ func (h *handshake) execute() *tcpip.Error {
sackEnabled = false
}
- // Send the initial SYN segment and loop until the handshake is
- // completed.
- h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
-
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: true,
@@ -544,9 +538,8 @@ func (h *handshake) execute() *tcpip.Error {
MSS: h.ep.amss,
}
- // Execute is also called in a listen context so we want to make sure we
- // only send the TS/SACK option when we received the TS/SACK in the
- // initial SYN.
+ // start() is also called in a listen context so we want to make sure we only
+ // send the TS/SACK option when we received the TS/SACK in the initial SYN.
if h.state == handshakeSynRcvd {
synOpts.TS = h.ep.sendTSOk
synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
@@ -557,7 +550,8 @@ func (h *handshake) execute() *tcpip.Error {
}
}
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.sendSYNOpts = synOpts
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -566,7 +560,25 @@ func (h *handshake) execute() *tcpip.Error {
ack: h.ackNum,
rcvWnd: h.rcvWnd,
}, synOpts)
+ return nil
+}
+
+// complete completes the TCP 3-way handshake initiated by h.start().
+func (h *handshake) complete() *tcpip.Error {
+ // Set up the wakers.
+ var s sleep.Sleeper
+ resendWaker := sleep.Waker{}
+ s.AddWaker(&resendWaker, wakerForResend)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+ defer s.Done()
+ // Initialize the resend timer.
+ timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert)
+ if err != nil {
+ return err
+ }
+ defer timer.stop()
for h.state != handshakeCompleted {
// Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
// throughout handshake processing).
@@ -576,11 +588,9 @@ func (h *handshake) execute() *tcpip.Error {
switch index {
case wakerForResend:
- timeOut *= 2
- if timeOut > MaxRTO {
- return tcpip.ErrTimeout
+ if err := timer.reset(); err != nil {
+ return err
}
- rt.Reset(timeOut)
// Resend the SYN/SYN-ACK only if the following conditions hold.
// - It's an active handshake (deferAccept does not apply)
// - It's a passive handshake and we have not yet got the final-ACK.
@@ -590,7 +600,7 @@ func (h *handshake) execute() *tcpip.Error {
// the connection with another ACK or data (as ACKs are never
// retransmitted on their own).
if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
- h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ h.ep.sendSynTCP(h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -598,7 +608,7 @@ func (h *handshake) execute() *tcpip.Error {
seq: h.iss,
ack: h.ackNum,
rcvWnd: h.rcvWnd,
- }, synOpts)
+ }, h.sendSYNOpts)
}
case wakerForNotification:
@@ -624,9 +634,8 @@ func (h *handshake) execute() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
- return h.ep.LastError()
+ return h.ep.lastErrorLocked()
}
-
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
return err
@@ -637,6 +646,34 @@ func (h *handshake) execute() *tcpip.Error {
return nil
}
+type backoffTimer struct {
+ timeout time.Duration
+ maxTimeout time.Duration
+ t *time.Timer
+}
+
+func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) {
+ if timeout > maxTimeout {
+ return nil, tcpip.ErrTimeout
+ }
+ bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout}
+ bt.t = time.AfterFunc(timeout, f)
+ return bt, nil
+}
+
+func (bt *backoffTimer) reset() *tcpip.Error {
+ bt.timeout *= 2
+ if bt.timeout > MaxRTO {
+ return tcpip.ErrTimeout
+ }
+ bt.t.Reset(bt.timeout)
+ return nil
+}
+
+func (bt *backoffTimer) stop() {
+ bt.t.Stop()
+}
+
func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck))
if synOpts.TS {
@@ -785,8 +822,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
data = data.Clone(nil)
optLen := len(tf.opts)
- if tf.rcvWnd > 0xffff {
- tf.rcvWnd = 0xffff
+ if tf.rcvWnd > math.MaxUint16 {
+ tf.rcvWnd = math.MaxUint16
}
mss := int(gso.MSS)
@@ -830,8 +867,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
// network endpoint and under the provided identity.
func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
optLen := len(tf.opts)
- if tf.rcvWnd > 0xffff {
- tf.rcvWnd = 0xffff
+ if tf.rcvWnd > math.MaxUint16 {
+ tf.rcvWnd = math.MaxUint16
}
if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
@@ -906,7 +943,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- err := e.sendTCP(&e.route, tcpFields{
+ err := e.sendTCP(e.route, tcpFields{
id: e.ID,
ttl: e.ttl,
tos: e.sendTOS,
@@ -967,7 +1004,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
// The exact sequence number to be used for the RST is the same as the
// one used by Linux. We need to handle the case of window being shrunk
@@ -1045,7 +1082,7 @@ func (e *endpoint) transitionToStateCloseLocked() {
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID)
- if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
+ if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
// Dual-stack socket, try IPv4.
ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID)
}
@@ -1106,7 +1143,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// delete the TCB, and return.
case StateCloseWait:
e.transitionToStateCloseLocked()
- e.HardError = tcpip.ErrAborted
+ e.hardError = tcpip.ErrAborted
e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
@@ -1251,7 +1288,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
userTimeout := e.userTimeout
e.keepalive.Lock()
- if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
+ if !e.SocketOptions().GetKeepAlive() || !e.keepalive.timer.checkExpiration() {
e.keepalive.Unlock()
return nil
}
@@ -1288,7 +1325,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
}
// Start the keepalive timer IFF it's enabled and there is no pending
// data to send.
- if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
+ if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
e.keepalive.timer.disable()
e.keepalive.Unlock()
return
@@ -1318,9 +1355,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
epilogue := func() {
// e.mu is expected to be hold upon entering this section.
-
if e.snd != nil {
e.snd.resendTimer.cleanup()
+ e.snd.rc.probeTimer.cleanup()
}
if closeTimer != nil {
@@ -1342,20 +1379,13 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
if handshake {
- // This is an active connection, so we must initiate the 3-way
- // handshake, and then inform potential waiters about its
- // completion.
- initialRcvWnd := e.initialReceiveWindow()
- h := newHandshake(e, seqnum.Size(initialRcvWnd))
- h.ep.setEndpointState(StateSynSent)
-
- if err := h.execute(); err != nil {
+ if err := e.h.complete(); err != nil {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
e.setEndpointState(StateError)
- e.HardError = err
+ e.hardError = err
e.workerCleanup = true
// Lock released below.
@@ -1364,9 +1394,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
}
- e.keepalive.timer.init(&e.keepalive.waker)
- defer e.keepalive.timer.cleanup()
-
drained := e.drainDone != nil
if drained {
close(e.drainDone)
@@ -1411,6 +1438,10 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
},
},
{
+ w: &e.snd.rc.probeWaker,
+ f: e.snd.probeTimerExpired,
+ },
+ {
w: &e.newSegmentWaker,
f: func() *tcpip.Error {
return e.handleSegments(false /* fastPath */)
@@ -1489,7 +1520,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
// Initialize the sleeper based on the wakers in funcs.
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
for i := range funcs {
s.AddWaker(funcs[i].w, i)
}
@@ -1613,7 +1644,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
}
extTW, newSyn := e.rcv.handleTimeWaitSegment(s)
if newSyn {
- info := e.EndpointInfo.TransportEndpointInfo
+ info := e.TransportEndpointInfo
newID := info.ID
newID.RemoteAddress = ""
newID.RemotePort = 0
@@ -1676,7 +1707,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
const notification = 2
const timeWaitDone = 3
- s := sleep.Sleeper{}
+ var s sleep.Sleeper
defer s.Done()
s.AddWaker(&e.newSegmentWaker, newSegment)
s.AddWaker(&e.notificationWaker, notification)
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index a6f25896b..1d1b01a6c 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -405,14 +405,6 @@ func testV4Accept(t *testing.T, c *context.Context) {
}
}
- // Make sure we get the same error when calling the original ep and the
- // new one. This validates that v4-mapped endpoints are still able to
- // query the V6Only flag, whereas pure v4 endpoints are not.
- _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption)
- if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected {
- t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
- }
-
// Check the peer address.
addr, err := nep.GetRemoteAddress()
if err != nil {
@@ -530,12 +522,12 @@ func TestV6AcceptOnV6(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
var addr tcpip.FullAddress
- nep, _, err := c.EP.Accept(&addr)
+ _, _, err := c.EP.Accept(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept(&addr)
+ _, _, err = c.EP.Accept(&addr)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -548,12 +540,6 @@ func TestV6AcceptOnV6(t *testing.T) {
if addr.Addr != context.TestV6Addr {
t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr)
}
-
- // Make sure we can still query the v6 only status of the new endpoint,
- // that is, that it is in fact a v6 socket.
- if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
- t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err)
- }
}
func TestV4AcceptOnV4(t *testing.T) {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 23b9de8c5..25b180fa5 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -17,6 +17,7 @@ package tcp
import (
"encoding/binary"
"fmt"
+ "io"
"math"
"runtime"
"strings"
@@ -27,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
@@ -310,16 +310,12 @@ type Stats struct {
func (*Stats) IsEndpointStats() {}
// EndpointInfo holds useful information about a transport endpoint which
-// can be queried by monitoring tools.
+// can be queried by monitoring tools. This exists to allow tcp-only state to
+// be exposed.
//
// +stateify savable
type EndpointInfo struct {
stack.TransportEndpointInfo
-
- // HardError is meaningful only when state is stateError. It stores the
- // error to be returned when read/write syscalls are called and the
- // endpoint is in this state. HardError is protected by endpoint mu.
- HardError *tcpip.Error `state:".(string)"`
}
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
@@ -367,6 +363,7 @@ func (*EndpointInfo) IsEndpointInfo() {}
// +stateify savable
type endpoint struct {
EndpointInfo
+ tcpip.DefaultSocketOptionsHandler
// endpointEntry is used to queue endpoints for processing to the
// a given tcp processor goroutine.
@@ -386,20 +383,38 @@ type endpoint struct {
waiterQueue *waiter.Queue `state:"wait"`
uniqueID uint64
+ // hardError is meaningful only when state is stateError. It stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. hardError is protected by endpoint mu.
+ hardError *tcpip.Error `state:".(string)"`
+
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
- // The following fields are used to manage the receive queue. The
- // protocol goroutine adds ready-for-delivery segments to rcvList,
- // which are returned by Read() calls to users.
+ // rcvReadMu synchronizes calls to Read.
//
- // Once the peer has closed its send side, rcvClosed is set to true
- // to indicate to users that no more data is coming.
+ // mu and rcvListMu are temporarily released during data copying. rcvReadMu
+ // must be held during each read to ensure atomicity, so that multiple reads
+ // do not interleave.
+ //
+ // rcvReadMu should be held before holding mu.
+ rcvReadMu sync.Mutex `state:"nosave"`
+
+ // rcvListMu synchronizes access to rcvList.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
+ rcvListMu sync.Mutex `state:"nosave"`
+
+ // rcvList is the queue for ready-for-delivery segments.
+ //
+ // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data
+ // and removing segments from list. A range of segment can be determined, then
+ // temporarily release mu and rcvListMu while processing the segment range.
+ // This allows new segments to be appended to the list while processing.
+ //
+ // rcvListMu must be held to append segments to list.
rcvList segmentList `state:"wait"`
rcvClosed bool
// rcvBufSize is the total size of the receive buffer.
@@ -421,7 +436,10 @@ type endpoint struct {
// mu protects all endpoint fields unless documented otherwise. mu must
// be acquired before interacting with the endpoint fields.
- mu sync.Mutex `state:"nosave"`
+ //
+ // During handshake, mu is locked by the protocol listen goroutine and
+ // released by the handshake completion goroutine.
+ mu sync.CrossGoroutineMutex `state:"nosave"`
ownedByUser uint32
// state must be read/set using the EndpointState()/setEndpointState()
@@ -436,13 +454,14 @@ type endpoint struct {
isPortReserved bool `state:"manual"`
isRegistered bool `state:"manual"`
boundNICID tcpip.NICID
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
ttl uint8
- v6only bool
isConnectNotified bool
- // TCP should never broadcast but Linux nevertheless supports enabling/
- // disabling SO_BROADCAST, albeit as a NOOP.
- broadcast bool
+
+ // h stores a reference to the current handshake state if the endpoint is in
+ // the SYN-SENT or SYN-RECV states, in which case endpoint == endpoint.h.ep.
+ // nil otherwise.
+ h *handshake `state:"nosave"`
// portFlags stores the current values of port related flags.
portFlags ports.Flags
@@ -489,6 +508,9 @@ type endpoint struct {
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
+ // tcpRecovery is the loss deteoction algorithm used by TCP.
+ tcpRecovery tcpip.TCPRecovery
+
// sackPermitted is set to true if the peer sends the TCPSACKPermitted
// option in the SYN/SYN-ACK.
sackPermitted bool
@@ -496,32 +518,14 @@ type endpoint struct {
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
- // bindToDevice is set to the NIC on which to bind or disabled if 0.
- bindToDevice tcpip.NICID
-
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
delay uint32
- // cork holds back segments until full.
- //
- // cork is a boolean (0 is false) and must be accessed atomically.
- cork uint32
-
// scoreboard holds TCP SACK Scoreboard information for this endpoint.
scoreboard *SACKScoreboard
- // The options below aren't implemented, but we remember the user
- // settings because applications expect to be able to set/query these
- // options.
-
- // slowAck holds the negated state of quick ack. It is stubbed out and
- // does nothing.
- //
- // slowAck is a boolean (0 is false) and must be accessed atomically.
- slowAck uint32
-
// segmentQueue is used to hand received segments to the protocol
// goroutine. Segments are queued as long as the queue is not full,
// and dropped when it is.
@@ -683,8 +687,8 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -696,7 +700,7 @@ func (e *endpoint) UniqueID() uint64 {
//
// If userMSS is non-zero and is not greater than the maximum possible MSS for
// r, it will be used; otherwise, the maximum possible MSS will be used.
-func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 {
// The maximum possible MSS is dependent on the route.
// TODO(b/143359391): Respect TCP Min and Max size.
maxMSS := uint16(r.MTU() - header.TCPMinimumSize)
@@ -845,7 +849,6 @@ func (e *endpoint) recentTimestamp() uint32 {
// +stateify savable
type keepalive struct {
sync.Mutex `state:"nosave"`
- enabled bool
idle time.Duration
interval time.Duration
count int
@@ -879,6 +882,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
windowClamp: DefaultReceiveBufferSize,
maxSynRetries: DefaultSynRetries,
}
+ e.ops.InitHandler(e)
+ e.ops.SetMulticastLoop(true)
+ e.ops.SetQuickAck(true)
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
@@ -902,7 +908,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
var de tcpip.TCPDelayEnabled
if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
- e.SetSockOptBool(tcpip.DelayOption, true)
+ e.ops.SetDelayOption(true)
}
var tcpLT tcpip.TCPLingerTimeoutOption
@@ -915,6 +921,8 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.maxSynRetries = uint8(synRetries)
}
+ s.TransportProtocolOption(ProtocolNumber, &e.tcpRecovery)
+
if p := s.GetTCPProbe(); p != nil {
e.probe = p
}
@@ -922,6 +930,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.segmentQueue.ep = e
e.tsOffset = timeStampOffset()
e.acceptCond = sync.NewCond(&e.acceptMu)
+ e.keepalive.timer.init(&e.keepalive.waker)
return e
}
@@ -1043,7 +1052,8 @@ func (e *endpoint) Close() {
return
}
- if e.linger.Enabled && e.linger.Timeout == 0 {
+ linger := e.SocketOptions().GetLinger()
+ if linger.Enabled && linger.Timeout == 0 {
s := e.EndpointState()
isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv
if isResetState {
@@ -1146,6 +1156,7 @@ func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
e.closePendingAcceptableConnectionsLocked()
+ e.keepalive.timer.cleanup()
e.workerCleanup = false
@@ -1162,7 +1173,11 @@ func (e *endpoint) cleanupLocked() {
e.boundPortFlags = ports.Flags{}
e.boundDest = tcpip.FullAddress{}
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
+
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
@@ -1272,11 +1287,20 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvListMu.Unlock()
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-func (e *endpoint) LastError() *tcpip.Error {
+// Preconditions: e.mu must be held to call this function.
+func (e *endpoint) hardErrorLocked() *tcpip.Error {
+ err := e.hardError
+ e.hardError = nil
+ return err
+}
+
+// Preconditions: e.mu must be held to call this function.
+func (e *endpoint) lastErrorLocked() *tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
err := e.lastError
@@ -1284,8 +1308,88 @@ func (e *endpoint) LastError() *tcpip.Error {
return err
}
-// Read reads data from the endpoint.
-func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// LastError implements tcpip.Endpoint.LastError.
+func (e *endpoint) LastError() *tcpip.Error {
+ e.LockUser()
+ defer e.UnlockUser()
+ if err := e.hardErrorLocked(); err != nil {
+ return err
+ }
+ return e.lastErrorLocked()
+}
+
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+ e.LockUser()
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+ e.UnlockUser()
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+ e.rcvReadMu.Lock()
+ defer e.rcvReadMu.Unlock()
+
+ // N.B. Here we get a range of segments to be processed. It is safe to not
+ // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we
+ // can remove segments from the list through commitRead().
+ first, last, serr := e.startRead()
+ if serr != nil {
+ if serr == tcpip.ErrClosedForReceive {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ }
+ return tcpip.ReadResult{}, serr
+ }
+
+ var err error
+ done := 0
+ s := first
+ for s != nil && done < count {
+ var n int
+ n, err = s.data.ReadTo(dst, count-done, opts.Peek)
+ // Book keeping first then error handling.
+
+ done += n
+
+ if opts.Peek {
+ // For peek, we use the (first, last) range of segment returned from
+ // startRead. We don't consume the receive buffer, so commitRead should
+ // not be called.
+ //
+ // N.B. It is important to use `last` to determine the last segment, since
+ // appending can happen while we process, and will lead to data race.
+ if s == last {
+ break
+ }
+ s = s.Next()
+ } else {
+ // N.B. commitRead() conveniently returns the next segment to read, after
+ // removing the data/segment that is read.
+ s = e.commitRead(n)
+ }
+
+ if err != nil {
+ break
+ }
+ }
+
+ // If something is read, we must report it. Report error when nothing is read.
+ if done == 0 && err != nil {
+ return tcpip.ReadResult{}, tcpip.ErrBadBuffer
+ }
+ return tcpip.ReadResult{
+ Count: done,
+ Total: done,
+ }, nil
+}
+
+// startRead checks that endpoint is in a readable state, and return the
+// inclusive range of segments that can be read.
+//
+// Precondition: e.rcvReadMu must be held.
+func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -1294,7 +1398,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// on a receive. It can expect to read any data after the handshake
// is complete. RFC793, section 3.9, p58.
if e.EndpointState() == StateSynSent {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ return nil, nil, tcpip.ErrWouldBlock
}
// The endpoint can be read if it's connected, or if it's already closed
@@ -1302,59 +1406,69 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// would cause the state to become StateError so we should allow the
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
bufUsed := e.rcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
- e.rcvListMu.Unlock()
- he := e.HardError
if s == StateError {
- return buffer.View{}, tcpip.ControlMessages{}, he
+ if err := e.hardErrorLocked(); err != nil {
+ return nil, nil, err
+ }
+ return nil, nil, tcpip.ErrClosedForReceive
}
e.stats.ReadErrors.NotConnected.Increment()
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
- }
-
- v, err := e.readLocked()
- e.rcvListMu.Unlock()
-
- if err == tcpip.ErrClosedForReceive {
- e.stats.ReadErrors.ReadClosed.Increment()
+ return nil, nil, tcpip.ErrNotConnected
}
- return v, tcpip.ControlMessages{}, err
-}
-func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.EndpointState().connected() {
- return buffer.View{}, tcpip.ErrClosedForReceive
+ return nil, nil, tcpip.ErrClosedForReceive
}
- return buffer.View{}, tcpip.ErrWouldBlock
+ return nil, nil, tcpip.ErrWouldBlock
}
- s := e.rcvList.Front()
- views := s.data.Views()
- v := views[s.viewToDeliver]
- s.viewToDeliver++
+ return e.rcvList.Front(), e.rcvList.Back(), nil
+}
- var delta int
- if s.viewToDeliver >= len(views) {
+// commitRead commits a read of done bytes and returns the next non-empty
+// segment to read. Data read from the segment must have also been removed from
+// the segment in order for this method to work correctly.
+//
+// It is performance critical to call commitRead frequently when servicing a big
+// Read request, so TCP can make progress timely. Right now, it is designed to
+// do this per segment read, hence this method conveniently returns the next
+// segment to read while holding the lock.
+//
+// Precondition: e.rcvReadMu must be held.
+func (e *endpoint) commitRead(done int) *segment {
+ e.LockUser()
+ defer e.UnlockUser()
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
+ memDelta := 0
+ s := e.rcvList.Front()
+ for s != nil && s.data.Size() == 0 {
e.rcvList.Remove(s)
- // We only free up receive buffer space when the segment is released as the
- // segment is still holding on to the views even though some views have been
- // read out to the user.
- delta = s.segMemSize()
+ // Memory is only considered released when the whole segment has been
+ // read.
+ memDelta += s.segMemSize()
s.decRef()
+ s = e.rcvList.Front()
}
+ e.rcvBufUsed -= done
- e.rcvBufUsed -= len(v)
- // If the window was small before this read and if the read freed up
- // enough buffer space, to either fit an aMSS or half a receive buffer
- // (whichever smaller), then notify the protocol goroutine to send a
- // window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above {
- e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ if memDelta > 0 {
+ // If the window was small before this read and if the read freed up
+ // enough buffer space, to either fit an aMSS or half a receive buffer
+ // (whichever smaller), then notify the protocol goroutine to send a
+ // window update.
+ if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
}
- return v, nil
+ return e.rcvList.Front()
}
// isEndpointWritableLocked checks if a given endpoint is writable
@@ -1363,9 +1477,13 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// indicating the reason why it's not writable.
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
+ // The endpoint cannot be written to if it's not connected.
switch s := e.EndpointState(); {
case s == StateError:
- return 0, e.HardError
+ if err := e.hardErrorLocked(); err != nil {
+ return 0, err
+ }
+ return 0, tcpip.ErrClosedForSend
case !s.connecting() && !s.connected():
return 0, tcpip.ErrClosedForSend
case s.connecting():
@@ -1426,104 +1544,38 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, perr
}
- queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) {
- // Add data to the send queue.
- s := newOutgoingSegment(e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
-
- // Do the work inline.
- e.handleWrite()
- e.UnlockUser()
- return int64(len(v)), nil, nil
- }
-
- if opts.Atomic {
- // Locks released in queueAndSend()
- return queueAndSend()
- }
-
- // Since we released locks in between it's possible that the
- // endpoint transitioned to a CLOSED/ERROR states so make
- // sure endpoint is still writable before trying to write.
- e.LockUser()
- e.sndBufMu.Lock()
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.UnlockUser()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
- }
-
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
- }
-
- // Locks released in queueAndSend()
- return queueAndSend()
-}
-
-// Peek reads data without consuming it from the endpoint.
-//
-// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- e.LockUser()
- defer e.UnlockUser()
-
- // The endpoint can be read if it's connected, or if it's already closed
- // but has some pending unread data.
- if s := e.EndpointState(); !s.connected() && s != StateClose {
- if s == StateError {
- return 0, tcpip.ControlMessages{}, e.HardError
+ if !opts.Atomic {
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
+ e.LockUser()
+ e.sndBufMu.Lock()
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.UnlockUser()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
}
- e.stats.ReadErrors.InvalidEndpointState.Increment()
- return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
- }
- e.rcvListMu.Lock()
- defer e.rcvListMu.Unlock()
-
- if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.EndpointState().connected() {
- e.stats.ReadErrors.ReadClosed.Increment()
- return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
}
- return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
}
- // Make a copy of vec so we can modify the slide headers.
- vec = append([][]byte(nil), vec...)
-
- var num int64
- for s := e.rcvList.Front(); s != nil; s = s.Next() {
- views := s.data.Views()
-
- for i := s.viewToDeliver; i < len(views); i++ {
- v := views[i]
-
- for len(v) > 0 {
- if len(vec) == 0 {
- return num, tcpip.ControlMessages{}, nil
- }
- if len(vec[0]) == 0 {
- vec = vec[1:]
- continue
- }
-
- n := copy(vec[0], v)
- v = v[n:]
- vec[0] = vec[0][n:]
- num += int64(n)
- }
- }
- }
+ // Add data to the send queue.
+ s := newOutgoingSegment(e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
+ e.sndBufMu.Unlock()
- return num, tcpip.ControlMessages{}, nil
+ // Do the work inline.
+ e.handleWrite()
+ e.UnlockUser()
+ return int64(len(v)), nil, nil
}
// selectWindowLocked returns the new window without checking for shrinking or scaling
@@ -1595,77 +1647,39 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo
return false, false
}
-// SetSockOptBool sets a socket option.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
-
- case tcpip.BroadcastOption:
- e.LockUser()
- e.broadcast = v
- e.UnlockUser()
-
- case tcpip.CorkOption:
- e.LockUser()
- if !v {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- e.UnlockUser()
-
- case tcpip.DelayOption:
- if v {
- atomic.StoreUint32(&e.delay, 1)
- } else {
- atomic.StoreUint32(&e.delay, 0)
-
- // Handle delayed data.
- e.sndWaker.Assert()
- }
-
- case tcpip.KeepaliveEnabledOption:
- e.keepalive.Lock()
- e.keepalive.enabled = v
- e.keepalive.Unlock()
- e.notifyProtocolGoroutine(notifyKeepaliveChanged)
-
- case tcpip.QuickAckOption:
- o := uint32(1)
- if v {
- o = 0
- }
- atomic.StoreUint32(&e.slowAck, o)
-
- case tcpip.ReuseAddressOption:
- e.LockUser()
- e.portFlags.TupleOnly = v
- e.UnlockUser()
-
- case tcpip.ReusePortOption:
- e.LockUser()
- e.portFlags.LoadBalanced = v
- e.UnlockUser()
+// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
+func (e *endpoint) OnReuseAddressSet(v bool) {
+ e.LockUser()
+ e.portFlags.TupleOnly = v
+ e.UnlockUser()
+}
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
+// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet.
+func (e *endpoint) OnReusePortSet(v bool) {
+ e.LockUser()
+ e.portFlags.LoadBalanced = v
+ e.UnlockUser()
+}
- // We only allow this to be set when we're in the initial state.
- if e.EndpointState() != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
+// OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet.
+func (e *endpoint) OnKeepAliveSet(v bool) {
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+}
- e.LockUser()
- e.v6only = v
- e.UnlockUser()
+// OnDelayOptionSet implements tcpip.SocketOptionsHandler.OnDelayOptionSet.
+func (e *endpoint) OnDelayOptionSet(v bool) {
+ if !v {
+ // Handle delayed data.
+ e.sndWaker.Assert()
}
+}
- return nil
+// OnCorkOptionSet implements tcpip.SocketOptionsHandler.OnCorkOptionSet.
+func (e *endpoint) OnCorkOptionSet(v bool) {
+ if !v {
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ }
}
// SetSockOptInt sets a socket option.
@@ -1825,18 +1839,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+func (e *endpoint) HasNIC(id int32) bool {
+ return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
- case *tcpip.BindToDeviceOption:
- id := tcpip.NICID(*v)
- if id != 0 && !e.stack.HasNIC(id) {
- return tcpip.ErrUnknownDevice
- }
- e.LockUser()
- e.bindToDevice = id
- e.UnlockUser()
-
case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
e.keepalive.idle = time.Duration(*v)
@@ -1849,9 +1858,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- case *tcpip.OutOfBandInlineOption:
- // We don't currently support disabling this option.
-
case *tcpip.TCPUserTimeoutOption:
e.LockUser()
e.userTimeout = time.Duration(*v)
@@ -1920,11 +1926,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.SocketDetachFilterOption:
return nil
- case *tcpip.LingerOption:
- e.LockUser()
- e.linger = *v
- e.UnlockUser()
-
default:
return nil
}
@@ -1947,72 +1948,6 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
return e.rcvBufUsed, nil
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.BroadcastOption:
- e.LockUser()
- v := e.broadcast
- e.UnlockUser()
- return v, nil
-
- case tcpip.CorkOption:
- return atomic.LoadUint32(&e.cork) != 0, nil
-
- case tcpip.DelayOption:
- return atomic.LoadUint32(&e.delay) != 0, nil
-
- case tcpip.KeepaliveEnabledOption:
- e.keepalive.Lock()
- v := e.keepalive.enabled
- e.keepalive.Unlock()
-
- return v, nil
-
- case tcpip.QuickAckOption:
- v := atomic.LoadUint32(&e.slowAck) == 0
- return v, nil
-
- case tcpip.ReuseAddressOption:
- e.LockUser()
- v := e.portFlags.TupleOnly
- e.UnlockUser()
-
- return v, nil
-
- case tcpip.ReusePortOption:
- e.LockUser()
- v := e.portFlags.LoadBalanced
- e.UnlockUser()
-
- return v, nil
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrUnknownProtocolOption
- }
-
- e.LockUser()
- v := e.v6only
- e.UnlockUser()
-
- return v, nil
-
- case tcpip.MulticastLoopOption:
- return true, nil
-
- case tcpip.AcceptConnOption:
- e.LockUser()
- defer e.UnlockUser()
-
- return e.EndpointState() == StateListen, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
@@ -2091,11 +2026,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
switch o := opt.(type) {
- case *tcpip.BindToDeviceOption:
- e.LockUser()
- *o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.UnlockUser()
-
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
e.LockUser()
@@ -2123,10 +2053,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
*o = tcpip.TCPUserTimeoutOption(e.userTimeout)
e.UnlockUser()
- case *tcpip.OutOfBandInlineOption:
- // We don't currently support disabling this option.
- *o = 1
-
case *tcpip.CongestionControlOption:
e.LockUser()
*o = e.cc
@@ -2155,11 +2081,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
Port: port,
}
- case *tcpip.LingerOption:
- e.LockUser()
- *o = e.linger
- e.UnlockUser()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -2169,7 +2090,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -2185,6 +2106,8 @@ func (*endpoint) Disconnect() *tcpip.Error {
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
err := e.connect(addr, true, true)
if err != nil && !err.IgnoreStats() {
+ // Connect failed. Let's wake up any waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
}
@@ -2244,7 +2167,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return tcpip.ErrAlreadyConnecting
case StateError:
- return e.HardError
+ if err := e.hardErrorLocked(); err != nil {
+ return err
+ }
+ return tcpip.ErrConnectionAborted
default:
return tcpip.ErrInvalidEndpointState
@@ -2302,11 +2228,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
}
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
if err != tcpip.ErrPortInUse || !reuse {
return false, nil
}
@@ -2344,15 +2271,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
tcpEP.notifyProtocolGoroutine(notifyAbort)
tcpEP.UnlockUser()
// Now try and Reserve again if it fails then we skip.
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
return false, nil
}
}
id := e.ID
id.LocalPort = p
- if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr)
+ if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr)
if err == tcpip.ErrPortInUse {
return false, nil
}
@@ -2363,7 +2290,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// the selected port.
e.ID = id
e.isPortReserved = true
- e.boundBindToDevice = e.bindToDevice
+ e.boundBindToDevice = bindToDevice
e.boundPortFlags = e.portFlags
e.boundDest = addr
return true, nil
@@ -2374,7 +2301,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
e.isRegistered = true
e.setEndpointState(StateConnecting)
- e.route = r.Clone()
+ r.Acquire()
+ e.route = r
e.boundNICID = nicID
e.effectiveNetProtos = netProtos
e.connectingAddress = connectingAddr
@@ -2397,14 +2325,70 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
if run {
- e.workerRunning = true
- e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
- go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
+ if err := e.startMainLoop(handshake); err != nil {
+ return err
+ }
}
return tcpip.ErrConnectStarted
}
+// startMainLoop sends the initial SYN and starts the main loop for the
+// endpoint.
+func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error {
+ preloop := func() *tcpip.Error {
+ if handshake {
+ h := e.newHandshake()
+ e.setEndpointState(StateSynSent)
+ if err := h.start(); err != nil {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ e.setEndpointState(StateError)
+ e.hardError = err
+
+ // Call cleanupLocked to free up any reservations.
+ e.cleanupLocked()
+ return err
+ }
+ }
+ e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
+ return nil
+ }
+
+ if e.route.IsResolutionRequired() {
+ // If the endpoint is closed between releasing e.mu and the goroutine below
+ // acquiring it, make sure that cleanup is deferred to the new goroutine.
+ e.workerRunning = true
+
+ // Sending the initial SYN may block due to route resolution; do it in a
+ // separate goroutine to avoid blocking the syscall goroutine.
+ go func() { // S/R-SAFE: will be drained before save.
+ e.mu.Lock()
+ if err := preloop(); err != nil {
+ e.workerRunning = false
+ e.mu.Unlock()
+ return
+ }
+ e.mu.Unlock()
+ _ = e.protocolMainLoop(handshake, nil)
+ }()
+ return nil
+ }
+
+ // No route resolution is required, so we can send the initial SYN here without
+ // blocking. This will hopefully reduce overall latency by overlapping time
+ // spent waiting for a SYN-ACK and time spent spinning up a new goroutine
+ // for the main loop.
+ if err := preloop(); err != nil {
+ return err
+ }
+ e.workerRunning = true
+ go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
+ return nil
+}
+
// ConnectEndpoint is not supported.
func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
@@ -2642,7 +2626,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// v6only set to false.
if netProto == header.IPv6ProtocolNumber {
stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber)
- alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4
+ alsoBindToV4 := !e.ops.GetV6Only() && addr.Addr == "" && stackHasV4
if alsoBindToV4 {
netProtos = append(netProtos, header.IPv4ProtocolNumber)
}
@@ -2659,7 +2643,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
e.ID.LocalAddress = addr.Addr
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
id := e.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
@@ -2670,7 +2655,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// demuxer. Further connected endpoints always have a remote
// address/port. Hence this will only return an error if there is a matching
// listening endpoint.
- if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil {
+ if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil {
return false
}
return true
@@ -2679,7 +2664,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
return err
}
- e.boundBindToDevice = e.bindToDevice
+ e.boundBindToDevice = bindToDevice
e.boundPortFlags = e.portFlags
// TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct.
e.boundNICID = nic
@@ -2727,7 +2712,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) {
// TCP HandlePacket is not required anymore as inbound packets first
- // land at the Dispatcher which then can either delivery using the
+ // land at the Dispatcher which then can either deliver using the
// worker go routine or directly do the invoke the tcp processing inline
// based on the state of the endpoint.
}
@@ -2743,8 +2728,43 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
return true
}
+func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+ // Update last error first.
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ // Update the error queue if IP_RECVERR is enabled.
+ if e.SocketOptions().GetRecvError() {
+ e.SocketOptions().QueueErr(&tcpip.SockError{
+ Err: err,
+ ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber),
+ ErrType: errType,
+ ErrCode: errCode,
+ ErrInfo: extra,
+ // Linux passes the payload with the TCP header. We don't know if the TCP
+ // header even exists, it may not for fragmented packets.
+ Payload: pkt.Data.ToView(),
+ Dst: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ },
+ NetProto: pkt.NetworkProtocolNumber,
+ })
+ }
+
+ // Notify of the error.
+ e.notifyProtocolGoroutine(notifyError)
+}
+
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
switch typ {
case stack.ControlPacketTooBig:
e.sndBufMu.Lock()
@@ -2757,16 +2777,10 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.notifyProtocolGoroutine(notifyMTUChanged)
case stack.ControlNoRoute:
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrNoRoute
- e.lastErrorMu.Unlock()
- e.notifyProtocolGoroutine(notifyError)
+ e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt)
case stack.ControlNetworkUnreachable:
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrNetworkUnreachable
- e.lastErrorMu.Unlock()
- e.notifyProtocolGoroutine(notifyError)
+ e.onICMPError(tcpip.ErrNetworkUnreachable, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt)
}
}
@@ -3024,6 +3038,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
Ssthresh: e.snd.sndSsthresh,
SndCAAckCount: e.snd.sndCAAckCount,
Outstanding: e.snd.outstanding,
+ SackedOut: e.snd.sackedOut,
SndWnd: e.snd.sndWnd,
SndUna: e.snd.sndUna,
SndNxt: e.snd.sndNxt,
@@ -3054,13 +3069,14 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
}
}
- rc := e.snd.rc
+ rc := &e.snd.rc
s.Sender.RACKState = stack.TCPRACKState{
XmitTime: rc.xmitTime,
EndSequence: rc.endSequence,
FACK: rc.fack,
RTT: rc.rtt,
Reord: rc.reorderSeen,
+ DSACKSeen: rc.dsackSeen,
}
return s
}
@@ -3105,7 +3121,7 @@ func (e *endpoint) State() uint32 {
func (e *endpoint) Info() tcpip.EndpointInfo {
e.LockUser()
// Make a copy of the endpoint info.
- ret := e.EndpointInfo
+ ret := e.TransportEndpointInfo
e.UnlockUser()
return &ret
}
@@ -3130,3 +3146,8 @@ func (e *endpoint) Wait() {
<-notifyCh
}
}
+
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
+func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 2bcc5e1c2..ba67176b5 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -172,6 +172,7 @@ func (e *endpoint) afterLoad() {
// Condition variables and mutexs are not S/R'ed so reinitialize
// acceptCond with e.acceptMu.
e.acceptCond = sync.NewCond(&e.acceptMu)
+ e.keepalive.timer.init(&e.keepalive.waker)
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
@@ -320,21 +321,21 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) {
}
// saveHardError is invoked by stateify.
-func (e *EndpointInfo) saveHardError() string {
- if e.HardError == nil {
+func (e *endpoint) saveHardError() string {
+ if e.hardError == nil {
return ""
}
- return e.HardError.String()
+ return e.hardError.String()
}
// loadHardError is invoked by stateify.
-func (e *EndpointInfo) loadHardError(s string) {
+func (e *endpoint) loadHardError(s string) {
if s == "" {
return
}
- e.HardError = tcpip.StringToError(s)
+ e.hardError = tcpip.StringToError(s)
}
// saveMeasureTime is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 0664789da..596178625 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -152,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
}
f := r.forwarder
- ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
+ ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{
MSS: r.synOptions.MSS,
WS: r.synOptions.WS,
TS: r.synOptions.TS,
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2329aca4b..c9e194f82 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -250,7 +250,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error
ttl = route.DefaultTTL()
}
- return sendTCP(&route, tcpFields{
+ return sendTCP(route, tcpFields{
id: s.id,
ttl: ttl,
tos: tos,
@@ -405,7 +405,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.E
case *tcpip.TCPRecovery:
p.mu.RLock()
- *v = tcpip.TCPRecovery(p.recovery)
+ *v = p.recovery
p.mu.RUnlock()
return nil
@@ -543,7 +543,8 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
minRTO: MinRTO,
maxRTO: MaxRTO,
maxRetries: MaxRetries,
- recovery: tcpip.TCPRACKLossDetection,
+ // TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection.
+ recovery: 0,
}
p.dispatcher.init(runtime.GOMAXPROCS(0))
return &p
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index d312b1b8b..5a4ee70f5 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -17,9 +17,18 @@ package tcp
import (
"time"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
+// wcDelayedACKTimeout is the recommended maximum delayed ACK timer value as
+// defined in https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.
+// It stands for worst case delayed ACK timer (WCDelAckT). When FlightSize is
+// 1, PTO is inflated by WCDelAckT time to compensate for a potential long
+// delayed ACK timer at the receiver.
+const wcDelayedACKTimeout = 200 * time.Millisecond
+
// RACK is a loss detection algorithm used in TCP to detect packet loss and
// reordering using transmission timestamp of the packets instead of packet or
// sequence counts. To use RACK, SACK should be enabled on the connection.
@@ -29,12 +38,12 @@ import (
//
// +stateify savable
type rackControl struct {
+ // dsackSeen indicates if the connection has seen a DSACK.
+ dsackSeen bool
+
// endSequence is the ending TCP sequence number of rackControl.seg.
endSequence seqnum.Value
- // dsack indicates if the connection has seen a DSACK.
- dsack bool
-
// fack is the highest selectively or cumulatively acknowledged
// sequence.
fack seqnum.Value
@@ -54,6 +63,23 @@ type rackControl struct {
// xmitTime is the latest transmission timestamp of rackControl.seg.
xmitTime time.Time `state:".(unixTime)"`
+
+ // probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm.
+ probeTimer timer `state:"nosave"`
+ probeWaker sleep.Waker `state:"nosave"`
+
+ // tlpRxtOut indicates whether there is an unacknowledged
+ // TLP retransmission.
+ tlpRxtOut bool
+
+ // tlpHighRxt the value of sender.sndNxt at the time of sending
+ // a TLP retransmission.
+ tlpHighRxt seqnum.Value
+}
+
+// init initializes RACK specific fields.
+func (rc *rackControl) init() {
+ rc.probeTimer.init(&rc.probeWaker)
}
// update will update the RACK related fields when an ACK has been received.
@@ -122,3 +148,103 @@ func (rc *rackControl) detectReorder(seg *segment) {
rc.reorderSeen = true
}
}
+
+// setDSACKSeen updates rack control if duplicate SACK is seen by the connection.
+func (rc *rackControl) setDSACKSeen() {
+ rc.dsackSeen = true
+}
+
+// shouldSchedulePTO dictates whether we should schedule a PTO or not.
+// See https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.1.
+func (s *sender) shouldSchedulePTO() bool {
+ // Schedule PTO only if RACK loss detection is enabled.
+ return s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 &&
+ // The connection supports SACK.
+ s.ep.sackPermitted &&
+ // The connection is not in loss recovery.
+ (s.state != RTORecovery && s.state != SACKRecovery) &&
+ // The connection has no SACKed sequences in the SACK scoreboard.
+ s.ep.scoreboard.Sacked() == 0
+}
+
+// schedulePTO schedules the probe timeout as defined in
+// https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.1.
+func (s *sender) schedulePTO() {
+ pto := time.Second
+ s.rtt.Lock()
+ if s.rtt.srttInited && s.rtt.srtt > 0 {
+ pto = s.rtt.srtt * 2
+ if s.outstanding == 1 {
+ pto += wcDelayedACKTimeout
+ }
+ }
+ s.rtt.Unlock()
+
+ now := time.Now()
+ if s.resendTimer.enabled() {
+ if now.Add(pto).After(s.resendTimer.target) {
+ pto = s.resendTimer.target.Sub(now)
+ }
+ s.resendTimer.disable()
+ }
+
+ s.rc.probeTimer.enable(pto)
+}
+
+// probeTimerExpired is the same as TLP_send_probe() as defined in
+// https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.2.
+func (s *sender) probeTimerExpired() *tcpip.Error {
+ if !s.rc.probeTimer.checkExpiration() {
+ return nil
+ }
+ // TODO(gvisor.dev/issue/5084): Implement this pseudo algorithm.
+ // If an unsent segment exists AND
+ // the receive window allows new data to be sent:
+ // Transmit the lowest-sequence unsent segment of up to SMSS
+ // Increment FlightSize by the size of the newly-sent segment
+ // Else if TLPRxtOut is not set:
+ // Retransmit the highest-sequence segment sent so far
+ // TLPRxtOut = true
+ // TLPHighRxt = SND.NXT
+ // The cwnd remains unchanged
+ // If FlightSize != 0:
+ // Arm RTO timer only.
+ return nil
+}
+
+// detectTLPRecovery detects if recovery was accomplished by the loss probes
+// and updates TLP state accordingly.
+// See https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.3.
+func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) {
+ if !(s.ep.sackPermitted && s.rc.tlpRxtOut) {
+ return
+ }
+
+ // Step 1.
+ if s.isDupAck(rcvdSeg) && ack == s.rc.tlpHighRxt {
+ var sbAboveTLPHighRxt bool
+ for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
+ if s.rc.tlpHighRxt.LessThan(sb.End) {
+ sbAboveTLPHighRxt = true
+ break
+ }
+ }
+ if !sbAboveTLPHighRxt {
+ // TLP episode is complete.
+ s.rc.tlpRxtOut = false
+ }
+ }
+
+ if s.rc.tlpRxtOut && s.rc.tlpHighRxt.LessThanEq(ack) {
+ // TLP episode is complete.
+ s.rc.tlpRxtOut = false
+ if !checkDSACK(rcvdSeg) {
+ // Step 2. Either the original packet or the retransmission (in the
+ // form of a probe) was lost. Invoke a congestion control response
+ // equivalent to fast recovery.
+ s.cc.HandleNDupAcks()
+ s.enterRecovery()
+ s.leaveRecovery()
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go
index c9dc7e773..76cad0831 100644
--- a/pkg/tcpip/transport/tcp/rack_state.go
+++ b/pkg/tcpip/transport/tcp/rack_state.go
@@ -27,3 +27,8 @@ func (rc *rackControl) saveXmitTime() unixTime {
func (rc *rackControl) loadXmitTime(unix unixTime) {
rc.xmitTime = time.Unix(unix.second, unix.nano)
}
+
+// afterLoad is invoked by stateify.
+func (rc *rackControl) afterLoad() {
+ rc.probeTimer.init(&rc.probeWaker)
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 8e0b7c843..405a6dce7 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -16,6 +16,7 @@ package tcp
import (
"container/heap"
+ "math"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -48,6 +49,10 @@ type receiver struct {
rcvWndScale uint8
+ // prevBufused is the snapshot of endpoint rcvBufUsed taken when we
+ // advertise a receive window.
+ prevBufUsed int
+
closed bool
// pendingRcvdSegments is bounded by the receive buffer size of the
@@ -80,9 +85,9 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
// outgoing packets, we should use what we have advertised for acceptability
// test.
scaledWindowSize := r.rcvWnd >> r.rcvWndScale
- if scaledWindowSize > 0xffff {
+ if scaledWindowSize > math.MaxUint16 {
// This is what we actually put in the Window field.
- scaledWindowSize = 0xffff
+ scaledWindowSize = math.MaxUint16
}
advertisedWindowSize := scaledWindowSize << r.rcvWndScale
return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize))
@@ -106,6 +111,34 @@ func (r *receiver) currentWindow() (curWnd seqnum.Size) {
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
newWnd := r.ep.selectWindow()
curWnd := r.currentWindow()
+ unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt))
+ bufUsed := r.ep.receiveBufferUsed()
+
+ // Grow the right edge of the window only for payloads larger than the
+ // the segment overhead OR if the application is actively consuming data.
+ //
+ // Avoiding growing the right edge otherwise, addresses a situation below:
+ // An application has been slow in reading data and we have burst of
+ // incoming segments lengths < segment overhead. Here, our available free
+ // memory would reduce drastically when compared to the advertised receive
+ // window.
+ //
+ // For example: With incoming 512 bytes segments, segment overhead of
+ // 552 bytes (at the time of writing this comment), with receive window
+ // starting from 1MB and with rcvAdvWndScale being 1, buffer would reach 0
+ // when the curWnd is still 19436 bytes, because for every incoming segment
+ // newWnd would reduce by (552+512) >> rcvAdvWndScale (current value 1),
+ // while curWnd would reduce by 512 bytes.
+ // Such a situation causes us to keep tail dropping the incoming segments
+ // and never advertise zero receive window to the peer.
+ //
+ // Linux does a similar check for minimal sk_buff size (128):
+ // https://github.com/torvalds/linux/blob/d5beb3140f91b1c8a3d41b14d729aefa4dcc58bc/net/ipv4/tcp_input.c#L783
+ //
+ // Also, if the application is reading the data, we keep growing the right
+ // edge, as we are still advertising a window that we think can be serviced.
+ toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed
+
// Update rcvAcc only if new window is > previously advertised window. We
// should never shrink the acceptable sequence space once it has been
// advertised the peer. If we shrink the acceptable sequence space then we
@@ -115,7 +148,7 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// rcvWUP rcvNxt rcvAcc new rcvAcc
// <=====curWnd ===>
// <========= newWnd > curWnd ========= >
- if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) {
+ if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow {
// If the new window moves the right edge, then update rcvAcc.
r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd))
} else {
@@ -130,11 +163,22 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// receiver's estimated RTT.
r.rcvWnd = newWnd
r.rcvWUP = r.rcvNxt
+ r.prevBufUsed = bufUsed
scaledWnd := r.rcvWnd >> r.rcvWndScale
if scaledWnd == 0 {
// Increment a metric if we are advertising an actual zero window.
r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
}
+
+ // If we started off with a window larger than what can he held in
+ // the 16bit window field, we ceil the value to the max value.
+ if scaledWnd > math.MaxUint16 {
+ scaledWnd = seqnum.Size(math.MaxUint16)
+
+ // Ensure that the stashed receive window always reflects what
+ // is being advertised.
+ r.rcvWnd = scaledWnd << r.rcvWndScale
+ }
return r.rcvNxt, scaledWnd
}
diff --git a/pkg/tcpip/transport/tcp/reno_recovery.go b/pkg/tcpip/transport/tcp/reno_recovery.go
new file mode 100644
index 000000000..2aa708e97
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/reno_recovery.go
@@ -0,0 +1,67 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+// renoRecovery stores the variables related to TCP Reno loss recovery
+// algorithm.
+//
+// +stateify savable
+type renoRecovery struct {
+ s *sender
+}
+
+func newRenoRecovery(s *sender) *renoRecovery {
+ return &renoRecovery{s: s}
+}
+
+func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) {
+ ack := rcvdSeg.ackNumber
+ snd := rr.s
+
+ // We are in fast recovery mode. Ignore the ack if it's out of range.
+ if !ack.InRange(snd.sndUna, snd.sndNxt+1) {
+ return
+ }
+
+ // Don't count this as a duplicate if it is carrying data or
+ // updating the window.
+ if rcvdSeg.logicalLen() != 0 || snd.sndWnd != rcvdSeg.window {
+ return
+ }
+
+ // Inflate the congestion window if we're getting duplicate acks
+ // for the packet we retransmitted.
+ if !fastRetransmit && ack == snd.fr.first {
+ // We received a dup, inflate the congestion window by 1 packet
+ // if we're not at the max yet. Only inflate the window if
+ // regular FastRecovery is in use, RFC6675 does not require
+ // inflating cwnd on duplicate ACKs.
+ if snd.sndCwnd < snd.fr.maxCwnd {
+ snd.sndCwnd++
+ }
+ return
+ }
+
+ // A partial ack was received. Retransmit this packet and remember it
+ // so that we don't retransmit it again.
+ //
+ // We don't inflate the window because we're putting the same packet
+ // back onto the wire.
+ //
+ // N.B. The retransmit timer will be reset by the caller.
+ snd.fr.first = ack
+ snd.dupAckCount = 0
+ snd.resendSegment()
+}
diff --git a/pkg/tcpip/transport/tcp/sack_recovery.go b/pkg/tcpip/transport/tcp/sack_recovery.go
new file mode 100644
index 000000000..7e813fa96
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack_recovery.go
@@ -0,0 +1,120 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+
+// sackRecovery stores the variables related to TCP SACK loss recovery
+// algorithm.
+//
+// +stateify savable
+type sackRecovery struct {
+ s *sender
+}
+
+func newSACKRecovery(s *sender) *sackRecovery {
+ return &sackRecovery{s: s}
+}
+
+// handleSACKRecovery implements the loss recovery phase as described in RFC6675
+// section 5, step C.
+func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) {
+ snd := sr.s
+ snd.SetPipe()
+
+ if smss := int(snd.ep.scoreboard.SMSS()); limit > smss {
+ // Cap segment size limit to s.smss as SACK recovery requires
+ // that all retransmissions or new segments send during recovery
+ // be of <= SMSS.
+ limit = smss
+ }
+
+ nextSegHint := snd.writeList.Front()
+ for snd.outstanding < snd.sndCwnd {
+ var nextSeg *segment
+ var rescueRtx bool
+ nextSeg, nextSegHint, rescueRtx = snd.NextSeg(nextSegHint)
+ if nextSeg == nil {
+ return dataSent
+ }
+ if !snd.isAssignedSequenceNumber(nextSeg) || snd.sndNxt.LessThanEq(nextSeg.sequenceNumber) {
+ // New data being sent.
+
+ // Step C.3 described below is handled by
+ // maybeSendSegment which increments sndNxt when
+ // a segment is transmitted.
+ //
+ // Step C.3 "If any of the data octets sent in
+ // (C.1) are above HighData, HighData must be
+ // updated to reflect the transmission of
+ // previously unsent data."
+ //
+ // We pass s.smss as the limit as the Step 2) requires that
+ // new data sent should be of size s.smss or less.
+ if sent := snd.maybeSendSegment(nextSeg, limit, end); !sent {
+ return dataSent
+ }
+ dataSent = true
+ snd.outstanding++
+ snd.writeNext = nextSeg.Next()
+ continue
+ }
+
+ // Now handle the retransmission case where we matched either step 1,3 or 4
+ // of the NextSeg algorithm.
+ // RFC 6675, Step C.4.
+ //
+ // "The estimate of the amount of data outstanding in the network
+ // must be updated by incrementing pipe by the number of octets
+ // transmitted in (C.1)."
+ snd.outstanding++
+ dataSent = true
+ snd.sendSegment(nextSeg)
+
+ segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
+ if rescueRtx {
+ // We do the last part of rule (4) of NextSeg here to update
+ // RescueRxt as until this point we don't know if we are going
+ // to use the rescue transmission.
+ snd.fr.rescueRxt = snd.fr.last
+ } else {
+ // RFC 6675, Step C.2
+ //
+ // "If any of the data octets sent in (C.1) are below
+ // HighData, HighRxt MUST be set to the highest sequence
+ // number of the retransmitted segment unless NextSeg ()
+ // rule (4) was invoked for this retransmission."
+ snd.fr.highRxt = segEnd - 1
+ }
+ }
+ return dataSent
+}
+
+func (sr *sackRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) {
+ snd := sr.s
+ if fastRetransmit {
+ snd.resendSegment()
+ }
+
+ // We are in fast recovery mode. Ignore the ack if it's out of range.
+ if ack := rcvdSeg.ackNumber; !ack.InRange(snd.sndUna, snd.sndNxt+1) {
+ return
+ }
+
+ // RFC 6675 recovery algorithm step C 1-5.
+ end := snd.sndUna.Add(snd.sndWnd)
+ dataSent := sr.handleSACKRecovery(snd.maxPayloadSize, end)
+ snd.postXmit(dataSent)
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 2091989cc..c5a6d2fba 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -37,7 +37,7 @@ const (
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
-// segment is mostly immutable, the only field allowed to change is viewToDeliver.
+// segment is mostly immutable, the only field allowed to change is data.
//
// +stateify savable
type segment struct {
@@ -60,10 +60,7 @@ type segment struct {
hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
- // viewToDeliver keeps track of the next View that should be
- // delivered by the Read endpoint.
- viewToDeliver int
+ views [8]buffer.View `state:"nosave"`
sequenceNumber seqnum.Value
ackNumber seqnum.Value
flags uint8
@@ -84,6 +81,9 @@ type segment struct {
// acked indicates if the segment has already been SACKed.
acked bool
+
+ // dataMemSize is the memory used by data initially.
+ dataMemSize int
}
func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
@@ -100,6 +100,7 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *
s.data = pkt.Data.Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
s.rcvdTime = time.Now()
+ s.dataMemSize = s.data.Size()
return s
}
@@ -113,6 +114,7 @@ func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment {
s.views[0] = v
s.data = buffer.NewVectorisedView(len(v), s.views[:1])
}
+ s.dataMemSize = s.data.Size()
return s
}
@@ -127,12 +129,12 @@ func (s *segment) clone() *segment {
netProto: s.netProto,
nicID: s.nicID,
remoteLinkAddr: s.remoteLinkAddr,
- viewToDeliver: s.viewToDeliver,
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
xmitCount: s.xmitCount,
ep: s.ep,
qFlags: s.qFlags,
+ dataMemSize: s.dataMemSize,
}
t.data = s.data.Clone(t.views[:])
return t
@@ -204,7 +206,7 @@ func (s *segment) payloadSize() int {
// segMemSize is the amount of memory used to hold the segment data and
// the associated metadata.
func (s *segment) segMemSize() int {
- return segSize + s.data.Size()
+ return SegSize + s.dataMemSize
}
// parse populates the sequence & ack numbers, flags, and window fields of the
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
index 7dc2741a6..7422d8c02 100644
--- a/pkg/tcpip/transport/tcp/segment_state.go
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -24,16 +24,11 @@ import (
func (s *segment) saveData() buffer.VectorisedView {
// We cannot save s.data directly as s.data.views may alias to s.views,
// which is not allowed by state framework (in-struct pointer).
- v := make([]buffer.View, len(s.data.Views()))
- // For views already delivered, we cannot save them directly as they may
- // have already been sliced and saved elsewhere (e.g., readViews).
- for i := 0; i < s.viewToDeliver; i++ {
- v[i] = append([]byte(nil), s.data.Views()[i]...)
+ vs := make([]buffer.View, len(s.data.Views()))
+ for i, v := range s.data.Views() {
+ vs[i] = v
}
- for i := s.viewToDeliver; i < len(v); i++ {
- v[i] = s.data.Views()[i]
- }
- return buffer.NewVectorisedView(s.data.Size(), v)
+ return buffer.NewVectorisedView(s.data.Size(), vs)
}
// loadData is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go
index 0ab7b8f56..392ff0859 100644
--- a/pkg/tcpip/transport/tcp/segment_unsafe.go
+++ b/pkg/tcpip/transport/tcp/segment_unsafe.go
@@ -19,5 +19,6 @@ import (
)
const (
- segSize = int(unsafe.Sizeof(segment{}))
+ // SegSize is the minimal size of the segment overhead.
+ SegSize = int(unsafe.Sizeof(segment{}))
)
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index ab5fa4fb7..079d90848 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -18,7 +18,6 @@ import (
"fmt"
"math"
"sort"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
@@ -92,6 +91,17 @@ type congestionControl interface {
PostRecovery()
}
+// lossRecovery is an interface that must be implemented by any supported
+// loss recovery algorithm.
+type lossRecovery interface {
+ // DoRecovery is invoked when loss is detected and segments need
+ // to be retransmitted. The cumulative or selective ACK is passed along
+ // with the flag which identifies whether the connection entered fast
+ // retransmit with this ACK and to retransmit the first unacknowledged
+ // segment.
+ DoRecovery(rcvdSeg *segment, fastRetransmit bool)
+}
+
// sender holds the state necessary to send TCP segments.
//
// +stateify savable
@@ -108,6 +118,9 @@ type sender struct {
// fr holds state related to fast recovery.
fr fastRecovery
+ // lr is the loss recovery algorithm used by the sender.
+ lr lossRecovery
+
// sndCwnd is the congestion window, in packets.
sndCwnd int
@@ -124,6 +137,9 @@ type sender struct {
// that have been sent but not yet acknowledged.
outstanding int
+ // sackedOut is the number of packets which are selectively acked.
+ sackedOut int
+
// sndWnd is the send window size.
sndWnd seqnum.Size
@@ -270,12 +286,16 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
gso: ep.gso != nil,
}
+ s.rc.init()
+
if s.gso {
s.ep.gso.MSS = uint16(maxPayloadSize)
}
s.cc = s.initCongestionControl(ep.cc)
+ s.lr = s.initLossRecovery()
+
// A negative sndWndScale means that no scaling is in use, otherwise we
// store the scaling value.
if sndWndScale > 0 {
@@ -330,6 +350,14 @@ func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionCon
}
}
+// initLossRecovery initiates the loss recovery algorithm for the sender.
+func (s *sender) initLossRecovery() lossRecovery {
+ if s.ep.sackPermitted {
+ return newSACKRecovery(s)
+ }
+ return newRenoRecovery(s)
+}
+
// updateMaxPayloadSize updates the maximum payload size based on the given
// MTU. If this is in response to "packet too big" control packets (indicated
// by the count argument), it also reduces the number of outstanding packets and
@@ -349,6 +377,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
m = 1
}
+ oldMSS := s.maxPayloadSize
s.maxPayloadSize = m
if s.gso {
s.ep.gso.MSS = uint16(m)
@@ -371,6 +400,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
// Rewind writeNext to the first segment exceeding the MTU. Do nothing
// if it is already before such a packet.
+ nextSeg := s.writeNext
for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
if seg == s.writeNext {
// We got to writeNext before we could find a segment
@@ -378,16 +408,22 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) {
break
}
- if seg.data.Size() > m {
+ if nextSeg == s.writeNext && seg.data.Size() > m {
// We found a segment exceeding the MTU. Rewind
// writeNext and try to retransmit it.
- s.writeNext = seg
- break
+ nextSeg = seg
+ }
+
+ if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ // Update sackedOut for new maximum payload size.
+ s.sackedOut -= s.pCount(seg, oldMSS)
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
}
// Since we likely reduced the number of outstanding packets, we may be
// ready to send some more.
+ s.writeNext = nextSeg
s.sendData()
}
@@ -497,6 +533,10 @@ func (s *sender) retransmitTimerExpired() bool {
s.ep.stack.Stats().TCP.Timeouts.Increment()
s.ep.stats.SendErrors.Timeouts.Increment()
+ // Set TLPRxtOut to false according to
+ // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.1.
+ s.rc.tlpRxtOut = false
+
// Give up if we've waited more than a minute since the last resend or
// if a user time out is set and we have exceeded the user specified
// timeout since the first retransmission.
@@ -550,7 +590,7 @@ func (s *sender) retransmitTimerExpired() bool {
// We were attempting fast recovery but were not successful.
// Leave the state. We don't need to update ssthresh because it
// has already been updated when entered fast-recovery.
- s.leaveFastRecovery()
+ s.leaveRecovery()
}
s.state = RTORecovery
@@ -606,13 +646,13 @@ func (s *sender) retransmitTimerExpired() bool {
// pCount returns the number of packets in the segment. Due to GSO, a segment
// can be composed of multiple packets.
-func (s *sender) pCount(seg *segment) int {
+func (s *sender) pCount(seg *segment, maxPayloadSize int) int {
size := seg.data.Size()
if size == 0 {
return 1
}
- return (size-1)/s.maxPayloadSize + 1
+ return (size-1)/maxPayloadSize + 1
}
// splitSeg splits a given segment at the size specified and inserts the
@@ -789,7 +829,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
}
if !nextTooBig && seg.data.Size() < available {
// Segment is not full.
- if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 {
+ if s.outstanding > 0 && s.ep.ops.GetDelayOption() {
// Nagle's algorithm. From Wikipedia:
// Nagle's algorithm works by
// combining a number of small
@@ -808,7 +848,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// send space and MSS.
// TODO(gvisor.dev/issue/2833): Drain the held segments after a
// timeout.
- if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 {
+ if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() {
return false
}
}
@@ -913,79 +953,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
return true
}
-// handleSACKRecovery implements the loss recovery phase as described in RFC6675
-// section 5, step C.
-func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) {
- s.SetPipe()
-
- if smss := int(s.ep.scoreboard.SMSS()); limit > smss {
- // Cap segment size limit to s.smss as SACK recovery requires
- // that all retransmissions or new segments send during recovery
- // be of <= SMSS.
- limit = smss
- }
-
- nextSegHint := s.writeList.Front()
- for s.outstanding < s.sndCwnd {
- var nextSeg *segment
- var rescueRtx bool
- nextSeg, nextSegHint, rescueRtx = s.NextSeg(nextSegHint)
- if nextSeg == nil {
- return dataSent
- }
- if !s.isAssignedSequenceNumber(nextSeg) || s.sndNxt.LessThanEq(nextSeg.sequenceNumber) {
- // New data being sent.
-
- // Step C.3 described below is handled by
- // maybeSendSegment which increments sndNxt when
- // a segment is transmitted.
- //
- // Step C.3 "If any of the data octets sent in
- // (C.1) are above HighData, HighData must be
- // updated to reflect the transmission of
- // previously unsent data."
- //
- // We pass s.smss as the limit as the Step 2) requires that
- // new data sent should be of size s.smss or less.
- if sent := s.maybeSendSegment(nextSeg, limit, end); !sent {
- return dataSent
- }
- dataSent = true
- s.outstanding++
- s.writeNext = nextSeg.Next()
- continue
- }
-
- // Now handle the retransmission case where we matched either step 1,3 or 4
- // of the NextSeg algorithm.
- // RFC 6675, Step C.4.
- //
- // "The estimate of the amount of data outstanding in the network
- // must be updated by incrementing pipe by the number of octets
- // transmitted in (C.1)."
- s.outstanding++
- dataSent = true
- s.sendSegment(nextSeg)
-
- segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
- if rescueRtx {
- // We do the last part of rule (4) of NextSeg here to update
- // RescueRxt as until this point we don't know if we are going
- // to use the rescue transmission.
- s.fr.rescueRxt = s.fr.last
- } else {
- // RFC 6675, Step C.2
- //
- // "If any of the data octets sent in (C.1) are below
- // HighData, HighRxt MUST be set to the highest sequence
- // number of the retransmitted segment unless NextSeg ()
- // rule (4) was invoked for this retransmission."
- s.fr.highRxt = segEnd - 1
- }
- }
- return dataSent
-}
-
func (s *sender) sendZeroWindowProbe() {
ack, win := s.ep.rcv.getSendParams()
s.unackZeroWindowProbes++
@@ -1014,6 +981,30 @@ func (s *sender) disableZeroWindowProbing() {
s.resendTimer.disable()
}
+func (s *sender) postXmit(dataSent bool) {
+ if dataSent {
+ // We sent data, so we should stop the keepalive timer to ensure
+ // that no keepalives are sent while there is pending data.
+ s.ep.disableKeepaliveTimer()
+ }
+
+ // If the sender has advertized zero receive window and we have
+ // data to be sent out, start zero window probing to query the
+ // the remote for it's receive window size.
+ if s.writeNext != nil && s.sndWnd == 0 {
+ s.enableZeroWindowProbing()
+ }
+
+ // Enable the timer if we have pending data and it's not enabled yet.
+ if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
+ s.resendTimer.enable(s.rto)
+ }
+ // If we have no more pending data, start the keepalive timer.
+ if s.sndUna == s.sndNxt {
+ s.ep.resetKeepaliveTimer(false)
+ }
+}
+
// sendData sends new data segments. It is called when data becomes available or
// when the send window opens up.
func (s *sender) sendData() {
@@ -1034,55 +1025,29 @@ func (s *sender) sendData() {
}
var dataSent bool
-
- // RFC 6675 recovery algorithm step C 1-5.
- if s.fr.active && s.ep.sackPermitted {
- dataSent = s.handleSACKRecovery(s.maxPayloadSize, end)
- } else {
- for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
- cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
- if cwndLimit < limit {
- limit = cwndLimit
- }
- if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
- // Move writeNext along so that we don't try and scan data that
- // has already been SACKED.
- s.writeNext = seg.Next()
- continue
- }
- if sent := s.maybeSendSegment(seg, limit, end); !sent {
- break
- }
- dataSent = true
- s.outstanding += s.pCount(seg)
+ for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
+ cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
+ if cwndLimit < limit {
+ limit = cwndLimit
+ }
+ if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ // Move writeNext along so that we don't try and scan data that
+ // has already been SACKED.
s.writeNext = seg.Next()
+ continue
}
+ if sent := s.maybeSendSegment(seg, limit, end); !sent {
+ break
+ }
+ dataSent = true
+ s.outstanding += s.pCount(seg, s.maxPayloadSize)
+ s.writeNext = seg.Next()
}
- if dataSent {
- // We sent data, so we should stop the keepalive timer to ensure
- // that no keepalives are sent while there is pending data.
- s.ep.disableKeepaliveTimer()
- }
-
- // If the sender has advertized zero receive window and we have
- // data to be sent out, start zero window probing to query the
- // the remote for it's receive window size.
- if s.writeNext != nil && s.sndWnd == 0 {
- s.enableZeroWindowProbing()
- }
-
- // Enable the timer if we have pending data and it's not enabled yet.
- if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
- s.resendTimer.enable(s.rto)
- }
- // If we have no more pending data, start the keepalive timer.
- if s.sndUna == s.sndNxt {
- s.ep.resetKeepaliveTimer(false)
- }
+ s.postXmit(dataSent)
}
-func (s *sender) enterFastRecovery() {
+func (s *sender) enterRecovery() {
s.fr.active = true
// Save state to reflect we're now in fast recovery.
//
@@ -1090,6 +1055,7 @@ func (s *sender) enterFastRecovery() {
// We inflate the cwnd by 3 to account for the 3 packets which triggered
// the 3 duplicate ACKs and are now not in flight.
s.sndCwnd = s.sndSsthresh + 3
+ s.sackedOut = 0
s.fr.first = s.sndUna
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
@@ -1098,13 +1064,16 @@ func (s *sender) enterFastRecovery() {
if s.ep.sackPermitted {
s.state = SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
+ // Set TLPRxtOut to false according to
+ // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.1.
+ s.rc.tlpRxtOut = false
return
}
s.state = FastRecovery
s.ep.stack.Stats().TCP.FastRecovery.Increment()
}
-func (s *sender) leaveFastRecovery() {
+func (s *sender) leaveRecovery() {
s.fr.active = false
s.fr.maxCwnd = 0
s.dupAckCount = 0
@@ -1115,57 +1084,6 @@ func (s *sender) leaveFastRecovery() {
s.cc.PostRecovery()
}
-func (s *sender) handleFastRecovery(seg *segment) (rtx bool) {
- ack := seg.ackNumber
- // We are in fast recovery mode. Ignore the ack if it's out of
- // range.
- if !ack.InRange(s.sndUna, s.sndNxt+1) {
- return false
- }
-
- // Leave fast recovery if it acknowledges all the data covered by
- // this fast recovery session.
- if s.fr.last.LessThan(ack) {
- s.leaveFastRecovery()
- return false
- }
-
- if s.ep.sackPermitted {
- // When SACK is enabled we let retransmission be governed by
- // the SACK logic.
- return false
- }
-
- // Don't count this as a duplicate if it is carrying data or
- // updating the window.
- if seg.logicalLen() != 0 || s.sndWnd != seg.window {
- return false
- }
-
- // Inflate the congestion window if we're getting duplicate acks
- // for the packet we retransmitted.
- if ack == s.fr.first {
- // We received a dup, inflate the congestion window by 1 packet
- // if we're not at the max yet. Only inflate the window if
- // regular FastRecovery is in use, RFC6675 does not require
- // inflating cwnd on duplicate ACKs.
- if s.sndCwnd < s.fr.maxCwnd {
- s.sndCwnd++
- }
- return false
- }
-
- // A partial ack was received. Retransmit this packet and
- // remember it so that we don't retransmit it again. We don't
- // inflate the window because we're putting the same packet back
- // onto the wire.
- //
- // N.B. The retransmit timer will be reset by the caller.
- s.fr.first = ack
- s.dupAckCount = 0
- return true
-}
-
// isAssignedSequenceNumber relies on the fact that we only set flags once a
// sequencenumber is assigned and that is only done right before we send the
// segment. As a result any segment that has a non-zero flag has a valid
@@ -1228,26 +1146,15 @@ func (s *sender) SetPipe() {
s.outstanding = pipe
}
-// checkDuplicateAck is called when an ack is received. It manages the state
-// related to duplicate acks and determines if a retransmit is needed according
-// to the rules in RFC 6582 (NewReno).
-func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
- ack := seg.ackNumber
- if s.fr.active {
- return s.handleFastRecovery(seg)
- }
-
- // We're not in fast recovery yet. A segment is considered a duplicate
- // only if it doesn't carry any data and doesn't update the send window,
- // because if it does, it wasn't sent in response to an out-of-order
- // segment. If SACK is enabled then we have an additional check to see
- // if the segment carries new SACK information. If it does then it is
- // considered a duplicate ACK as per RFC6675.
- if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt {
- if !s.ep.sackPermitted || !seg.hasNewSACKInfo {
- s.dupAckCount = 0
- return false
- }
+// detectLoss is called when an ack is received and returns whether a loss is
+// detected. It manages the state related to duplicate acks and determines if
+// a retransmit is needed according to the rules in RFC 6582 (NewReno).
+func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
+ // We're not in fast recovery yet.
+
+ if !s.isDupAck(seg) {
+ s.dupAckCount = 0
+ return false
}
s.dupAckCount++
@@ -1266,18 +1173,43 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2
//
// We only do the check here, the incrementing of last to the highest
- // sequence number transmitted till now is done when enterFastRecovery
+ // sequence number transmitted till now is done when enterRecovery
// is invoked.
if !s.fr.last.LessThan(seg.ackNumber) {
s.dupAckCount = 0
return false
}
s.cc.HandleNDupAcks()
- s.enterFastRecovery()
+ s.enterRecovery()
s.dupAckCount = 0
return true
}
+// isDupAck determines if seg is a duplicate ack as defined in
+// https://tools.ietf.org/html/rfc5681#section-2.
+func (s *sender) isDupAck(seg *segment) bool {
+ // A TCP that utilizes selective acknowledgments (SACKs) [RFC2018, RFC2883]
+ // can leverage the SACK information to determine when an incoming ACK is a
+ // "duplicate" (e.g., if the ACK contains previously unknown SACK
+ // information).
+ if s.ep.sackPermitted && !seg.hasNewSACKInfo {
+ return false
+ }
+
+ // (a) The receiver of the ACK has outstanding data.
+ return s.sndUna != s.sndNxt &&
+ // (b) The incoming acknowledgment carries no data.
+ seg.logicalLen() == 0 &&
+ // (c) The SYN and FIN bits are both off.
+ !seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) &&
+ // (d) the ACK number is equal to the greatest acknowledgment received on
+ // the given connection (TCP.UNA from RFC793).
+ seg.ackNumber == s.sndUna &&
+ // (e) the advertised window in the incoming acknowledgment equals the
+ // advertised window in the last incoming acknowledgment.
+ s.sndWnd == seg.window
+}
+
// Iterate the writeList and update RACK for each segment which is newly acked
// either cumulatively or selectively. Loop through the segments which are
// sacked, and update the RACK related variables and check for reordering.
@@ -1285,36 +1217,85 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
func (s *sender) walkSACK(rcvdSeg *segment) {
- if len(rcvdSeg.parsedOptions.SACKBlocks) == 0 {
+ // Look for DSACK block.
+ idx := 0
+ n := len(rcvdSeg.parsedOptions.SACKBlocks)
+ if checkDSACK(rcvdSeg) {
+ s.rc.setDSACKSeen()
+ idx = 1
+ n--
+ }
+
+ if n == 0 {
return
}
// Sort the SACK blocks. The first block is the most recent unacked
// block. The following blocks can be in arbitrary order.
- sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks))
- copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks)
+ sackBlocks := make([]header.SACKBlock, n)
+ copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks[idx:])
sort.Slice(sackBlocks, func(i, j int) bool {
return sackBlocks[j].Start.LessThan(sackBlocks[i].Start)
})
seg := s.writeList.Front()
for _, sb := range sackBlocks {
- // This check excludes DSACK blocks.
- if sb.Start.LessThanEq(rcvdSeg.ackNumber) || sb.Start.LessThanEq(s.sndUna) || s.sndNxt.LessThan(sb.End) {
- continue
- }
-
for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 {
if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked {
s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
s.rc.detectReorder(seg)
seg.acked = true
+ s.sackedOut += s.pCount(seg, s.maxPayloadSize)
}
seg = seg.Next()
}
}
}
+// checkDSACK checks if a DSACK is reported.
+func checkDSACK(rcvdSeg *segment) bool {
+ n := len(rcvdSeg.parsedOptions.SACKBlocks)
+ if n == 0 {
+ return false
+ }
+
+ sb := rcvdSeg.parsedOptions.SACKBlocks[0]
+ // Check if SACK block is invalid.
+ if sb.End.LessThan(sb.Start) {
+ return false
+ }
+
+ // See: https://tools.ietf.org/html/rfc2883#section-5 DSACK is sent in
+ // at most one SACK block. DSACK is detected in the below two cases:
+ // * If the SACK sequence space is less than this cumulative ACK, it is
+ // an indication that the segment identified by the SACK block has
+ // been received more than once by the receiver.
+ // * If the sequence space in the first SACK block is greater than the
+ // cumulative ACK, then the sender next compares the sequence space
+ // in the first SACK block with the sequence space in the second SACK
+ // block, if there is one. This comparison can determine if the first
+ // SACK block is reporting duplicate data that lies above the
+ // cumulative ACK.
+ if sb.Start.LessThan(rcvdSeg.ackNumber) {
+ return true
+ }
+
+ if n > 1 {
+ sb1 := rcvdSeg.parsedOptions.SACKBlocks[1]
+ if sb1.End.LessThan(sb1.Start) {
+ return false
+ }
+
+ // If the first SACK block is fully covered by second SACK
+ // block, then the first block is a DSACK block.
+ if sb.End.LessThanEq(sb1.End) && sb1.Start.LessThanEq(sb.Start) {
+ return true
+ }
+ }
+
+ return false
+}
+
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
@@ -1367,14 +1348,26 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.SetPipe()
}
- // Count the duplicates and do the fast retransmit if needed.
- rtx := s.checkDuplicateAck(rcvdSeg)
+ ack := rcvdSeg.ackNumber
+ fastRetransmit := false
+ // Do not leave fast recovery, if the ACK is out of range.
+ if s.fr.active {
+ // Leave fast recovery if it acknowledges all the data covered by
+ // this fast recovery session.
+ if ack.InRange(s.sndUna, s.sndNxt+1) && s.fr.last.LessThan(ack) {
+ s.leaveRecovery()
+ }
+ } else {
+ // Detect loss by counting the duplicates and enter recovery.
+ fastRetransmit = s.detectLoss(rcvdSeg)
+ }
+
+ // See if TLP based recovery was successful.
+ s.detectTLPRecovery(ack, rcvdSeg)
// Stash away the current window size.
s.sndWnd = rcvdSeg.window
- ack := rcvdSeg.ackNumber
-
// Disable zero window probing if remote advertizes a non-zero receive
// window. This can be with an ACK to the zero window probe (where the
// acknumber refers to the already acknowledged byte) OR to any previously
@@ -1429,10 +1422,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
datalen := seg.logicalLen()
if datalen > ackLeft {
- prevCount := s.pCount(seg)
+ prevCount := s.pCount(seg, s.maxPayloadSize)
seg.data.TrimFront(int(ackLeft))
seg.sequenceNumber.UpdateForward(ackLeft)
- s.outstanding -= prevCount - s.pCount(seg)
+ s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize)
break
}
@@ -1448,11 +1441,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
s.writeList.Remove(seg)
- // If SACK is enabled then Only reduce outstanding if
+ // If SACK is enabled then only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
- s.outstanding -= s.pCount(seg)
+ s.outstanding -= s.pCount(seg, s.maxPayloadSize)
+ } else {
+ s.sackedOut -= s.pCount(seg, s.maxPayloadSize)
}
seg.decRef()
ackLeft -= datalen
@@ -1489,21 +1484,27 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Reset firstRetransmittedSegXmitTime to the zero value.
s.firstRetransmittedSegXmitTime = time.Time{}
s.resendTimer.disable()
+ s.rc.probeTimer.disable()
}
}
+
// Now that we've popped all acknowledged data from the retransmit
// queue, retransmit if needed.
- if rtx {
- s.resendSegment()
+ if s.fr.active {
+ s.lr.DoRecovery(rcvdSeg, fastRetransmit)
+ // When SACK is enabled data sending is governed by steps in
+ // RFC 6675 Section 5 recovery steps A-C.
+ // See: https://tools.ietf.org/html/rfc6675#section-5.
+ if s.ep.sackPermitted {
+ return
+ }
}
// Send more data now that some of the pending data has been ack'd, or
// that the window opened up, or the congestion window was inflated due
// to a duplicate ack during fast recovery. This will also re-enable
// the retransmit timer if needed.
- if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo {
- s.sendData()
- }
+ s.sendData()
}
// sendSegment sends the specified segment.
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index d3f92b48c..9818ffa0f 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -30,15 +30,17 @@ const (
maxPayload = 10
tsOptionSize = 12
maxTCPOptionSize = 40
+ mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
)
// TestRACKUpdate tests the RACK related fields are updated when an ACK is
// received on a SACK enabled connection.
func TestRACKUpdate(t *testing.T) {
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ c := context.New(t, uint32(mtu))
defer c.Cleanup()
var xmitTime time.Time
+ probeDone := make(chan struct{})
c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
// Validate that the endpoint Sender.RACKState is what we expect.
if state.Sender.RACKState.XmitTime.Before(xmitTime) {
@@ -54,6 +56,7 @@ func TestRACKUpdate(t *testing.T) {
if state.Sender.RACKState.RTT == 0 {
t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0")
}
+ close(probeDone)
})
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
@@ -73,18 +76,20 @@ func TestRACKUpdate(t *testing.T) {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
- time.Sleep(200 * time.Millisecond)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ <-probeDone
}
// TestRACKDetectReorder tests that RACK detects packet reordering.
func TestRACKDetectReorder(t *testing.T) {
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ c := context.New(t, uint32(mtu))
defer c.Cleanup()
- const ackNum = 2
-
var n int
- ch := make(chan struct{})
+ const ackNumToVerify = 2
+ probeDone := make(chan struct{})
c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
gotSeq := state.Sender.RACKState.FACK
wantSeq := state.Sender.SndNxt
@@ -95,7 +100,7 @@ func TestRACKDetectReorder(t *testing.T) {
}
n++
- if n < ackNum {
+ if n < ackNumToVerify {
if state.Sender.RACKState.Reord {
t.Fatalf("RACK reorder detected when there is no reordering")
}
@@ -105,11 +110,11 @@ func TestRACKDetectReorder(t *testing.T) {
if state.Sender.RACKState.Reord == false {
t.Fatalf("RACK reorder detection failed")
}
- close(ch)
+ close(probeDone)
})
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- data := buffer.NewView(ackNum * maxPayload)
+ data := buffer.NewView(ackNumToVerify * maxPayload)
for i := range data {
data[i] = byte(i)
}
@@ -120,7 +125,7 @@ func TestRACKDetectReorder(t *testing.T) {
}
bytesRead := 0
- for i := 0; i < ackNum; i++ {
+ for i := 0; i < ackNumToVerify; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
}
@@ -133,5 +138,393 @@ func TestRACKDetectReorder(t *testing.T) {
// Wait for the probe function to finish processing the ACK before the
// test completes.
- <-ch
+ <-probeDone
+}
+
+func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View {
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ data := buffer.NewView(numPackets * maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write the data.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ bytesRead := 0
+ for i := 0; i < numPackets; i++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+
+ return data
+}
+
+const (
+ validDSACKDetected = 1
+ failedToDetectDSACK = 2
+ invalidDSACKDetected = 3
+)
+
+func addDSACKSeenCheckerProbe(t *testing.T, c *context.Context, numACK int, probeDone chan int) {
+ var n int
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that RACK detects DSACK.
+ n++
+ if n < numACK {
+ if state.Sender.RACKState.DSACKSeen {
+ probeDone <- invalidDSACKDetected
+ }
+ return
+ }
+
+ if !state.Sender.RACKState.DSACKSeen {
+ probeDone <- failedToDetectDSACK
+ return
+ }
+ probeDone <- validDSACKDetected
+ })
+}
+
+// TestRACKDetectDSACK tests that RACK detects DSACK with duplicate segments.
+// See: https://tools.ietf.org/html/rfc2883#section-4.1.1.
+func TestRACKDetectDSACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan int)
+ const ackNumToVerify = 2
+ addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
+
+ numPackets := 8
+ data := sendAndReceive(t, c, numPackets)
+
+ // Cumulative ACK for [1-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ bytesRead := 5 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Expect retransmission of #6 packet.
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+
+ // Send DSACK block for #6 packet indicating both
+ // initial and retransmitted packet are received and
+ // packets [1-7] are received.
+ start := c.IRS.Add(seqnum.Size(bytesRead))
+ end := start.Add(maxPayload)
+ bytesRead += 2 * maxPayload
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ err := <-probeDone
+ switch err {
+ case failedToDetectDSACK:
+ t.Fatalf("RACK DSACK detection failed")
+ case invalidDSACKDetected:
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+}
+
+// TestRACKDetectDSACKWithOutOfOrder tests that RACK detects DSACK with out of
+// order segments.
+// See: https://tools.ietf.org/html/rfc2883#section-4.1.2.
+func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan int)
+ const ackNumToVerify = 2
+ addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
+
+ numPackets := 10
+ data := sendAndReceive(t, c, numPackets)
+
+ // Cumulative ACK for [1-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ bytesRead := 5 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Expect retransmission of #6 packet.
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+
+ // Send DSACK block for #6 packet indicating both
+ // initial and retransmitted packet are received and
+ // packets [1-7] are received.
+ start := c.IRS.Add(seqnum.Size(bytesRead))
+ end := start.Add(maxPayload)
+ bytesRead += 2 * maxPayload
+ // Send DSACK block for #6 along with out of
+ // order #9 packet is received.
+ start1 := c.IRS.Add(seqnum.Size(bytesRead) + maxPayload)
+ end1 := start1.Add(maxPayload)
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}, {start1, end1}})
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ err := <-probeDone
+ switch err {
+ case failedToDetectDSACK:
+ t.Fatalf("RACK DSACK detection failed")
+ case invalidDSACKDetected:
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+}
+
+// TestRACKDetectDSACKWithOutOfOrderDup tests that DSACK is detected on a
+// duplicate of out of order packet.
+// See: https://tools.ietf.org/html/rfc2883#section-4.1.3
+func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan int)
+ const ackNumToVerify = 4
+ addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
+
+ numPackets := 10
+ sendAndReceive(t, c, numPackets)
+
+ // ACK [1-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ bytesRead := 5 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Send SACK indicating #6 packet is missing and received #7 packet.
+ offset := seqnum.Size(bytesRead + maxPayload)
+ start := c.IRS.Add(1 + offset)
+ end := start.Add(maxPayload)
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Send SACK with #6 packet is missing and received [7-8] packets.
+ end = start.Add(2 * maxPayload)
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Consider #8 packet is duplicated on the network and send DSACK.
+ dsackStart := c.IRS.Add(1 + offset + maxPayload)
+ dsackEnd := dsackStart.Add(maxPayload)
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ err := <-probeDone
+ switch err {
+ case failedToDetectDSACK:
+ t.Fatalf("RACK DSACK detection failed")
+ case invalidDSACKDetected:
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+}
+
+// TestRACKDetectDSACKSingleDup tests DSACK for a single duplicate subsegment.
+// See: https://tools.ietf.org/html/rfc2883#section-4.2.1.
+func TestRACKDetectDSACKSingleDup(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan int)
+ const ackNumToVerify = 4
+ addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
+
+ numPackets := 4
+ data := sendAndReceive(t, c, numPackets)
+
+ // Send ACK for #1 packet.
+ bytesRead := maxPayload
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAck(seq, bytesRead)
+
+ // Missing [2-3] packets and received #4 packet.
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(1 + seqnum.Size(3*maxPayload))
+ end := start.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Expect retransmission of #2 packet.
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+
+ // ACK for retransmitted #2 packet.
+ bytesRead += maxPayload
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Simulate receving delayed subsegment of #2 packet and delayed #3 packet by
+ // sending DSACK block for the subsegment.
+ dsackStart := c.IRS.Add(1 + seqnum.Size(bytesRead))
+ dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2))
+ c.SendAckWithSACK(seq, numPackets*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}})
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ err := <-probeDone
+ switch err {
+ case failedToDetectDSACK:
+ t.Fatalf("RACK DSACK detection failed")
+ case invalidDSACKDetected:
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+}
+
+// TestRACKDetectDSACKDupWithCumulativeACK tests DSACK for two non-contiguous
+// duplicate subsegments covered by the cumulative acknowledgement.
+// See: https://tools.ietf.org/html/rfc2883#section-4.2.2.
+func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan int)
+ const ackNumToVerify = 5
+ addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
+
+ numPackets := 6
+ data := sendAndReceive(t, c, numPackets)
+
+ // Send ACK for #1 packet.
+ bytesRead := maxPayload
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAck(seq, bytesRead)
+
+ // Missing [2-5] packets and received #6 packet.
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(1 + seqnum.Size(5*maxPayload))
+ end := start.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Expect retransmission of #2 packet.
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+
+ // Received delayed #2 packet.
+ bytesRead += maxPayload
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Received delayed #4 packet.
+ start1 := c.IRS.Add(1 + seqnum.Size(3*maxPayload))
+ end1 := start1.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}})
+
+ // Simulate receiving retransmitted subsegment for #2 packet and delayed #3
+ // packet by sending DSACK block for #2 packet.
+ dsackStart := c.IRS.Add(1 + seqnum.Size(maxPayload))
+ dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2))
+ c.SendAckWithSACK(seq, 4*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ err := <-probeDone
+ switch err {
+ case failedToDetectDSACK:
+ t.Fatalf("RACK DSACK detection failed")
+ case invalidDSACKDetected:
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+}
+
+// TestRACKDetectDSACKDup tests two non-contiguous duplicate subsegments not
+// covered by the cumulative acknowledgement.
+// See: https://tools.ietf.org/html/rfc2883#section-4.2.3.
+func TestRACKDetectDSACKDup(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan int)
+ const ackNumToVerify = 5
+ addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
+
+ numPackets := 7
+ data := sendAndReceive(t, c, numPackets)
+
+ // Send ACK for #1 packet.
+ bytesRead := maxPayload
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAck(seq, bytesRead)
+
+ // Missing [2-6] packets and SACK #7 packet.
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
+ end := start.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Received delayed #3 packet.
+ start1 := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
+ end1 := start1.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}})
+
+ // Expect retransmission of #2 packet.
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+
+ // Consider #2 packet has been dropped and SACK #4 packet.
+ start2 := c.IRS.Add(1 + seqnum.Size(3*maxPayload))
+ end2 := start2.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start2, end2}, {start1, end1}, {start, end}})
+
+ // Simulate receiving retransmitted subsegment for #3 packet and delayed #5
+ // packet by sending DSACK block for the subsegment.
+ dsackStart := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
+ dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2))
+ end1 = end1.Add(seqnum.Size(2 * maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start1, end1}})
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ err := <-probeDone
+ switch err {
+ case failedToDetectDSACK:
+ t.Fatalf("RACK DSACK detection failed")
+ case invalidDSACKDetected:
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+}
+
+// TestRACKWithInvalidDSACKBlock tests that DSACK is not detected when DSACK
+// is not the first SACK block.
+func TestRACKWithInvalidDSACKBlock(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan struct{})
+ const ackNumToVerify = 2
+ var n int
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that RACK does not detect DSACK when DSACK block is
+ // not the first SACK block.
+ n++
+ t.Helper()
+ if state.Sender.RACKState.DSACKSeen {
+ t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
+ }
+
+ if n == ackNumToVerify {
+ close(probeDone)
+ }
+ })
+
+ numPackets := 10
+ data := sendAndReceive(t, c, numPackets)
+
+ // Cumulative ACK for [1-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ bytesRead := 5 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Expect retransmission of #6 packet.
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+
+ // Send DSACK block for #6 packet indicating both
+ // initial and retransmitted packet are received and
+ // packets [1-7] are received.
+ start := c.IRS.Add(seqnum.Size(bytesRead))
+ end := start.Add(maxPayload)
+ bytesRead += 2 * maxPayload
+
+ // Send DSACK block as second block.
+ start1 := c.IRS.Add(seqnum.Size(bytesRead) + maxPayload)
+ end1 := start1.Add(maxPayload)
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}})
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ <-probeDone
}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index ef7f5719f..faf0c0ad7 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -590,3 +590,45 @@ func TestSACKRecovery(t *testing.T) {
expected++
}
}
+
+// TestSACKUpdateSackedOut tests the sacked out field is updated when a SACK
+// is received.
+func TestSACKUpdateSackedOut(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan struct{})
+ ackNum := 0
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint Sender.SackedOut is what we expect.
+ if state.Sender.SackedOut != 2 && ackNum == 0 {
+ t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut)
+ }
+
+ if state.Sender.SackedOut != 0 && ackNum == 1 {
+ t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut)
+ }
+ if ackNum > 0 {
+ close(probeDone)
+ }
+ ackNum++
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ sendAndReceive(t, c, 8)
+
+ // ACK for [3-5] packets.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload))
+ bytesRead := 2 * maxPayload
+ end := start.Add(seqnum.Size(bytesRead))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ bytesRead += 3 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ <-probeDone
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index fcc3c5000..aeceee7e0 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -17,10 +17,12 @@ package tcp_test
import (
"bytes"
"fmt"
+ "io/ioutil"
"math"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -40,6 +42,64 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// endpointTester provides helper functions to test a tcpip.Endpoint.
+type endpointTester struct {
+ ep tcpip.Endpoint
+}
+
+// CheckReadError issues a read to the endpoint and checking for an error.
+func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) {
+ t.Helper()
+ res, got := e.ep.Read(ioutil.Discard, 1, tcpip.ReadOptions{})
+ if got != want {
+ t.Fatalf("ep.Read = %s, want %s", got, want)
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" {
+ t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff)
+ }
+}
+
+// CheckRead issues a read to the endpoint and checking for a success, returning
+// the data read.
+func (e *endpointTester) CheckRead(t *testing.T, count int) []byte {
+ t.Helper()
+ var buf bytes.Buffer
+ res, err := e.ep.Read(&buf, count, tcpip.ReadOptions{})
+ if err != nil {
+ t.Fatalf("ep.Read = _, %s; want _, nil", err)
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ return buf.Bytes()
+}
+
+// CheckReadFull reads from the endpoint for exactly count bytes.
+func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte {
+ t.Helper()
+ var buf bytes.Buffer
+ var done int
+ for done < count {
+ res, err := e.ep.Read(&buf, count-done, tcpip.ReadOptions{})
+ if err == tcpip.ErrWouldBlock {
+ // Wait for receive to be notified.
+ select {
+ case <-notifyRead:
+ case <-time.After(timeout):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+ continue
+ } else if err != nil {
+ t.Fatalf("ep.Read = _, %s; want _, nil", err)
+ }
+ done += res.Count
+ }
+ return buf.Bytes()
+}
+
const (
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -75,9 +135,6 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
- if err := ep.LastError(); err != tcpip.ErrAborted {
- t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted)
- }
// Call Connect again to retreive the handshake failure status
// and stats updates.
@@ -267,7 +324,7 @@ func TestTCPResetsSentNoICMP(t *testing.T) {
}
// Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded.
- sent := stats.ICMP.V4PacketsSent
+ sent := stats.ICMP.V4.PacketsSent
if got, want := sent.DstUnreachable.Value(), uint64(0); got != want {
t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want)
}
@@ -743,9 +800,7 @@ func TestSimpleReceive(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -765,11 +820,7 @@ func TestSimpleReceive(t *testing.T) {
}
// Receive data.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
+ v := ept.CheckRead(t, defaultMTU)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -1383,9 +1434,8 @@ func TestConnectBindToDevice(t *testing.T) {
defer c.Cleanup()
c.Create(-1)
- bindToDevice := tcpip.BindToDeviceOption(test.device)
- if err := c.EP.SetSockOpt(&bindToDevice); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", bindToDevice, bindToDevice, err)
+ if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err)
}
// Start connection attempt.
waitEntry, _ := waiter.NewChannelEntry(nil)
@@ -1496,14 +1546,11 @@ func TestSynSent(t *testing.T) {
t.Fatal("timed out waiting for packet to arrive")
}
+ ept := endpointTester{c.EP}
if test.reset {
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
- }
+ ept.CheckReadError(t, tcpip.ErrConnectionRefused)
} else {
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted)
- }
+ ept.CheckReadError(t, tcpip.ErrAborted)
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
@@ -1528,9 +1575,8 @@ func TestOutOfOrderReceive(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send second half of data first, with seqnum 3 ahead of expected.
data := []byte{1, 2, 3, 4, 5, 6}
@@ -1555,9 +1601,7 @@ func TestOutOfOrderReceive(t *testing.T) {
// Wait 200ms and check that no data has been received.
time.Sleep(200 * time.Millisecond)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send the first 3 bytes now.
c.SendPacket(data[:3], &context.Headers{
@@ -1570,24 +1614,7 @@ func TestOutOfOrderReceive(t *testing.T) {
})
// Receive data.
- read := make([]byte, 0, 6)
- for len(read) < len(data) {
- v, _, err := c.EP.Read(nil)
- if err != nil {
- if err == tcpip.ErrWouldBlock {
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
- continue
- }
- t.Fatalf("Read failed: %s", err)
- }
-
- read = append(read, v...)
- }
+ read := ept.CheckReadFull(t, 6, ch, 5*time.Second)
// Check that we received the data in proper order.
if !bytes.Equal(data, read) {
@@ -1612,9 +1639,8 @@ func TestOutOfOrderFlood(t *testing.T) {
rcvBufSz := math.MaxUint16
c.CreateConnected(789, 30000, rcvBufSz)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send 100 packets before the actual one that is expected.
data := []byte{1, 2, 3, 4, 5, 6}
@@ -1689,9 +1715,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -1758,9 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -1841,17 +1865,14 @@ func TestShutdownRead(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
t.Fatalf("Shutdown failed: %s", err)
}
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
- }
+ ept.CheckReadError(t, tcpip.ErrClosedForReceive)
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
@@ -1869,10 +1890,8 @@ func TestFullWindowReceive(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- _, _, err := c.EP.Read(nil)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %s", err)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
// the provided buffer value by tcp.SegOverheadFactor to calculate the actual
@@ -1909,11 +1928,7 @@ func TestFullWindowReceive(t *testing.T) {
)
// Receive data and check it.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
+ v := ept.CheckRead(t, defaultMTU)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -1935,6 +1950,85 @@ func TestFullWindowReceive(t *testing.T) {
)
}
+// Test the stack receive window advertisement on receiving segments smaller than
+// segment overhead. It tests for the right edge of the window to not grow when
+// the endpoint is not being read from.
+func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize,
+ Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)),
+ }
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
+
+ c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Bump up the receive buffer size such that, when the receive window grows,
+ // the scaled window exceeds maxUint16.
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err)
+ }
+
+ // Keep the payload size < segment overhead and such that it is a multiple
+ // of the window scaled value. This enables the test to perform equality
+ // checks on the incoming receive window.
+ payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale))
+ payloadLen := seqnum.Size(len(payload))
+ iss := seqnum.Value(789)
+ seqNum := iss.Add(1)
+
+ // Send payload to the endpoint and return the advertised receive window
+ // from the endpoint.
+ getIncomingRcvWnd := func() uint32 {
+ c.SendPacket(payload, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: seqNum,
+ AckNum: c.IRS.Add(1),
+ Flags: header.TCPFlagAck,
+ RcvWnd: 30000,
+ })
+ seqNum = seqNum.Add(payloadLen)
+
+ pkt := c.GetPacket()
+ return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale
+ }
+
+ // Read the advertised receive window with the ACK for payload.
+ rcvWnd := getIncomingRcvWnd()
+
+ // Check if the subsequent ACK to our send has not grown the right edge of
+ // the window.
+ if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+
+ // Read the data so that the subsequent ACK from the endpoint
+ // grows the right edge of the window.
+ var buf bytes.Buffer
+ if _, err := c.EP.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}); err != nil {
+ t.Fatalf("c.EP.Read: %s", err)
+ }
+
+ // Check if we have received max uint16 as our advertised
+ // scaled window now after a read above.
+ maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale)
+ if got, want := getIncomingRcvWnd(), maxRcv; got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+
+ // Check if the subsequent ACK to our send has not grown the right edge of
+ // the window.
+ if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want {
+ t.Fatalf("got incomingRcvwnd %d want %d", got, want)
+ }
+}
+
func TestNoWindowShrinking(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -1953,9 +2047,9 @@ func TestNoWindowShrinking(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
+
// Send a 1 byte payload so that we can record the current receive window.
// Send a payload of half the size of rcvBufSize.
seqNum := iss.Add(1)
@@ -1977,11 +2071,7 @@ func TestNoWindowShrinking(t *testing.T) {
}
// Read the 1 byte payload we just sent.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
- if got, want := payload, v; !bytes.Equal(got, want) {
+ if got, want := payload, ept.CheckRead(t, 1); !bytes.Equal(got, want) {
t.Fatalf("got data: %v, want: %v", got, want)
}
@@ -2054,24 +2144,8 @@ func TestNoWindowShrinking(t *testing.T) {
),
)
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
// Receive data and check it.
- read := make([]byte, 0, rcvBufSize)
- for len(read) < len(data) {
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
- read = append(read, v...)
- }
-
+ read := ept.CheckReadFull(t, len(data), ch, 5*time.Second)
if !bytes.Equal(data, read) {
t.Fatalf("got data = %v, want = %v", read, data)
}
@@ -2495,11 +2569,11 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// we need to read at 3 packets.
sz := 0
for sz < defaultMTU*2 {
- v, _, err := c.EP.Read(nil)
+ res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Read failed: %s", err)
}
- sz += len(v)
+ sz += res.Count
}
checker.IPv4(t, c.GetPacket(),
@@ -2532,10 +2606,10 @@ func TestSegmentMerging(t *testing.T) {
{
"cork",
func(ep tcpip.Endpoint) {
- ep.SetSockOptBool(tcpip.CorkOption, true)
+ ep.SocketOptions().SetCorkOption(true)
},
func(ep tcpip.Endpoint) {
- ep.SetSockOptBool(tcpip.CorkOption, false)
+ ep.SocketOptions().SetCorkOption(false)
},
},
}
@@ -2627,7 +2701,7 @@ func TestDelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptBool(tcpip.DelayOption, true)
+ c.EP.SocketOptions().SetDelayOption(true)
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
@@ -2675,7 +2749,7 @@ func TestUndelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptBool(tcpip.DelayOption, true)
+ c.EP.SocketOptions().SetDelayOption(true)
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
@@ -2708,7 +2782,7 @@ func TestUndelay(t *testing.T) {
// Check that we don't get the second packet yet.
c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
- c.EP.SetSockOptBool(tcpip.DelayOption, false)
+ c.EP.SocketOptions().SetDelayOption(false)
// Check that data is received.
second := c.GetPacket()
@@ -2745,8 +2819,8 @@ func TestMSSNotDelayed(t *testing.T) {
fn func(tcpip.Endpoint)
}{
{"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }},
- {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }},
+ {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }},
+ {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }},
}
for _, test := range tests {
@@ -3194,10 +3268,15 @@ func TestReceiveOnResetConnection(t *testing.T) {
loop:
for {
- switch _, _, err := c.EP.Read(nil); err {
+ switch _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err {
case tcpip.ErrWouldBlock:
select {
case <-ch:
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ if _, err := c.EP.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset)
+ }
+ break loop
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for reset to arrive")
}
@@ -3207,14 +3286,10 @@ loop:
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
- // Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
- }
+
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
-
if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
}
@@ -3386,7 +3461,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
- idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}}
+ idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}}
// Expect two retransmitted packets, and that all packets received have
// unique IPv4 ID values.
for i := 0; i <= 2; i++ {
@@ -4089,9 +4164,8 @@ func TestReadAfterClosedState(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Shutdown immediately for write, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -4149,35 +4223,31 @@ func TestReadAfterClosedState(t *testing.T) {
}
// Check that peek works.
- peekBuf := make([]byte, 10)
- n, _, err := c.EP.Peek([][]byte{peekBuf})
+ var peekBuf bytes.Buffer
+ res, err := c.EP.Read(&peekBuf, 10, tcpip.ReadOptions{Peek: true})
if err != nil {
t.Fatalf("Peek failed: %s", err)
}
- peekBuf = peekBuf[:n]
- if !bytes.Equal(data, peekBuf) {
- t.Fatalf("got data = %v, want = %v", peekBuf, data)
+ if got, want := res.Count, len(data); got != want {
+ t.Fatalf("res.Count = %d, want %d", got, want)
}
-
- // Receive data.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
+ if !bytes.Equal(data, peekBuf.Bytes()) {
+ t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data)
}
+ // Receive data.
+ v := ept.CheckRead(t, defaultMTU)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
// Now that we drained the queue, check that functions fail with the
// right error code.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
- }
-
- if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
+ ept.CheckReadError(t, tcpip.ErrClosedForReceive)
+ var buf bytes.Buffer
+ if _, err := c.EP.Read(&buf, 1, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("c.EP.Read(_, _, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive)
}
}
@@ -4193,9 +4263,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4205,9 +4273,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4218,9 +4284,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4233,9 +4297,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4246,9 +4308,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4261,9 +4321,7 @@ func TestReusePort(t *testing.T) {
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
- t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
- }
+ c.EP.SocketOptions().SetReuseAddress(true)
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
@@ -4443,7 +4501,7 @@ func TestBindToDeviceOption(t *testing.T) {
name string
setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
- getBindToDevice tcpip.BindToDeviceOption
+ getBindToDevice int32
}{
{"GetDefaultValue", nil, nil, 0},
{"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
@@ -4453,15 +4511,13 @@ func TestBindToDeviceOption(t *testing.T) {
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
- bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ bindToDevice := int32(*testAction.setBindToDevice)
+ if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption(88888)
- if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err)
- } else if bindToDevice != testAction.getBindToDevice {
+ bindToDevice := ep.SocketOptions().GetBindToDevice()
+ if bindToDevice != testAction.getBindToDevice {
t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice)
}
})
@@ -4558,17 +4614,8 @@ func TestSelfConnect(t *testing.T) {
// Read back what was written.
wq.EventUnregister(&waitEntry)
wq.EventRegister(&waitEntry, waiter.EventIn)
- rd, _, err := ep.Read(nil)
- if err != nil {
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %s", err)
- }
- <-notifyCh
- rd, _, err = ep.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
- }
+ ept := endpointTester{ep}
+ rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second)
if !bytes.Equal(data, rd) {
t.Fatalf("got data = %v, want = %v", rd, data)
@@ -4656,13 +4703,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
switch network {
case "ipv4":
case "ipv6":
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
case "dual":
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
- t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err)
- }
+ ep.SocketOptions().SetV6Only(false)
default:
t.Fatalf("unknown network: '%s'", network)
}
@@ -4998,9 +5041,7 @@ func TestKeepalive(t *testing.T) {
if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil {
t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil {
- t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err)
- }
+ c.EP.SocketOptions().SetKeepAlive(true)
// 5 unacked keepalives are sent. ACK each one, and check that the
// connection stays alive after 5.
@@ -5027,9 +5068,8 @@ func TestKeepalive(t *testing.T) {
}
// Check that the connection is still alive.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
@@ -5118,9 +5158,7 @@ func TestKeepalive(t *testing.T) {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
}
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
- }
+ ept.CheckReadError(t, tcpip.ErrTimeout)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
@@ -5660,16 +5698,14 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
t.Fatalf("Bind failed: %s", err)
}
- // Test acceptance.
// Start listening.
listenBacklog := 1
- portOffset := uint16(0)
if err := c.EP.Listen(listenBacklog); err != nil {
t.Fatalf("Listen failed: %s", err)
}
- executeHandshake(t, c, context.TestPort+portOffset, false)
- portOffset++
+ executeHandshake(t, c, context.TestPort, false)
+
// Wait for this to be delivered to the accept queue.
time.Sleep(50 * time.Millisecond)
@@ -5717,6 +5753,50 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
}
+func TestSYNRetransmit(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Start listening.
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send the same SYN packet multiple times. We should still get a valid SYN-ACK
+ // reply.
+ irs := seqnum.Value(789)
+ for i := 0; i < 5; i++ {
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Receive the SYN-ACK reply.
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.TCPAckNum(uint32(irs) + 1),
+ }
+ checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...))
+}
+
func TestSynRcvdBadSeqNumber(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5971,9 +6051,8 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
- if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
- t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected)
- }
+ ept := endpointTester{ep}
+ ept.CheckReadError(t, tcpip.ErrNotConnected)
if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
}
@@ -6074,10 +6153,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Introduce a 25ms latency by delaying the first byte.
latency := 25 * time.Millisecond
time.Sleep(latency)
- rawEP.SendPacketWithTS([]byte{1}, tsVal)
+ // Send an initial payload with atleast segment overhead size. The receive
+ // window would not grow for smaller segments.
+ rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal)
pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize()
+
time.Sleep(25 * time.Millisecond)
// Allocate a large enough payload for the test.
@@ -6125,7 +6207,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Now read all the data from the endpoint and verify that advertised
// window increases to the full available buffer size.
for {
- _, _, err := c.EP.Read(nil)
+ _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
break
}
@@ -6249,11 +6331,11 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// to happen before we measure the new window.
totalCopied := 0
for {
- b, _, err := c.EP.Read(nil)
+ res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
break
}
- totalCopied += len(b)
+ totalCopied += res.Count
}
// Invoke the moderation API. This is required for auto-tuning
@@ -6350,10 +6432,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.T
if err != nil {
t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
}
- gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption)
- if err != nil {
- t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err)
- }
+ gotDelayOption := ep.SocketOptions().GetDelayOption()
if gotDelayOption != wantDelayOption {
t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption)
}
@@ -7173,9 +7252,8 @@ func TestTCPUserTimeout(t *testing.T) {
),
)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrTimeout)
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
@@ -7206,9 +7284,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil {
t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err)
}
- if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil {
- t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err)
- }
+ c.EP.SocketOptions().SetKeepAlive(true)
// Set userTimeout to be the duration to be 1 keepalive
// probes. Which means that after the first probe is sent
@@ -7220,9 +7296,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
}
// Check that the connection is still alive.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Now receive 1 keepalives, but don't ACK it.
b := c.GetPacket()
@@ -7261,9 +7336,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
),
)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
- }
+ ept.CheckReadError(t, tcpip.ErrTimeout)
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
}
@@ -7320,11 +7393,11 @@ func TestIncreaseWindowOnRead(t *testing.T) {
// defaultMTU is a good enough estimate for the MSS used for this
// connection.
for read < defaultMTU*2 {
- v, _, err := c.EP.Read(nil)
+ res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Read failed: %s", err)
}
- read += len(v)
+ read += res.Count
}
// After reading > MSS worth of data, we surely crossed MSS. See the ack:
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 0f9ed06cd..9e02d467d 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -20,6 +20,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -105,11 +106,18 @@ func TestTimeStampEnabledConnect(t *testing.T) {
// There should be 5 views to read and each of them should
// contain the same data.
for i := 0; i < 5; i++ {
- got, _, err := c.EP.Read(nil)
+ var buf bytes.Buffer
+ result, err := c.EP.Read(&buf, len(data), tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
- if want := data; bytes.Compare(got, want) != 0 {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 {
t.Fatalf("Data is different: got: %v, want: %v", got, want)
}
}
@@ -286,11 +294,18 @@ func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) {
}
// Issue a read and we should data.
- got, _, err := c.EP.Read(nil)
+ var buf bytes.Buffer
+ result, err := c.EP.Read(&buf, defaultMTU, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
- if want := data; bytes.Compare(got, want) != 0 {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 {
t.Fatalf("Data is different: got: %v, want: %v", got, want)
}
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index e6aa4fc4b..ee55f030c 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -592,9 +592,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
+ c.EP.SocketOptions().SetV6Only(v6only)
}
// GetV6Packet reads a single packet from the link layer endpoint of the context
@@ -637,11 +635,11 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
- NextHeader: uint8(tcp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: src,
- DstAddr: dst,
+ PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
+ TransportProtocol: tcp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: src,
+ DstAddr: dst,
})
// Initialize the TCP header.
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index 7981d469b..38a335840 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -84,6 +84,10 @@ func (t *timer) init(w *sleep.Waker) {
// cleanup frees all resources associated with the timer.
func (t *timer) cleanup() {
+ if t.timer == nil {
+ // No cleanup needed.
+ return
+ }
t.timer.Stop()
*t = timer{}
}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index c78549424..153e8c950 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -56,6 +56,8 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 9bcb918bb..5d87f3a7e 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -16,8 +16,9 @@ package udp
import (
"fmt"
+ "io"
+ "sync/atomic"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -30,10 +31,11 @@ import (
// +stateify savable
type udpPacket struct {
udpPacketEntry
- senderAddress tcpip.FullAddress
- packetInfo tcpip.IPPacketInfo
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- timestamp int64
+ senderAddress tcpip.FullAddress
+ destinationAddress tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
// tos stores either the receiveTOS or receiveTClass value.
tos uint8
}
@@ -77,6 +79,7 @@ func (s EndpointState) String() string {
// +stateify savable
type endpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
@@ -94,22 +97,19 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- sndBufSizeMax int
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ // state must be read/set using the EndpointState()/setEndpointState()
+ // methods.
state EndpointState
- route stack.Route `state:"manual"`
+ route *stack.Route `state:"manual"`
dstPort uint16
- v6only bool
ttl uint8
multicastTTL uint8
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
- multicastLoop bool
portFlags ports.Flags
- bindToDevice tcpip.NICID
- broadcast bool
- noChecksum bool
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -123,17 +123,6 @@ type endpoint struct {
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
- // receiveTOS determines if the incoming IPv4 TOS header field is passed
- // as ancillary data to ControlMessages on Read.
- receiveTOS bool
-
- // receiveTClass determines if the incoming IPv6 TClass header field is
- // passed as ancillary data to ControlMessages on Read.
- receiveTClass bool
-
- // receiveIPPacketInfo determines if the packet info is returned by Read.
- receiveIPPacketInfo bool
-
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -155,8 +144,8 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// +stateify savable
@@ -186,13 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
//
// Linux defaults to TTL=1.
multicastTTL: 1,
- multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
multicastMemberships: make(map[multicastMembership]struct{}),
state: StateInitial,
uniqueID: s.UniqueID(),
}
+ e.ops.InitHandler(e)
+ e.ops.SetMulticastLoop(true)
// Override with stack defaults.
var ss stack.SendBufferSizeOption
@@ -208,6 +198,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
return e
}
+// setEndpointState updates the state of the endpoint to state atomically. This
+// method is unexported as the only place we should update the state is in this
+// package but we allow the state to be read freely without holding e.mu.
+//
+// Precondition: e.mu must be held to call this method.
+func (e *endpoint) setEndpointState(state EndpointState) {
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+}
+
+// EndpointState() returns the current state of the endpoint.
+func (e *endpoint) EndpointState() EndpointState {
+ return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+}
+
// UniqueID implements stack.TransportEndpoint.UniqueID.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
@@ -222,6 +226,13 @@ func (e *endpoint) LastError() *tcpip.Error {
return err
}
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+}
+
// Abort implements stack.TransportEndpoint.Abort.
func (e *endpoint) Abort() {
e.Close()
@@ -233,7 +244,7 @@ func (e *endpoint) Close() {
e.mu.Lock()
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.state {
+ switch e.EndpointState() {
case StateBound, StateConnected:
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
@@ -256,10 +267,13 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- e.route.Release()
+ if e.route != nil {
+ e.route.Release()
+ e.route = nil
+ }
// Update the state.
- e.state = StateClosed
+ e.setEndpointState(StateClosed)
e.mu.Unlock()
@@ -269,11 +283,10 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
if err := e.LastError(); err != nil {
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
e.rcvMu.Lock()
@@ -285,41 +298,54 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
- e.rcvMu.Unlock()
-
- if addr != nil {
- *addr = p.senderAddress
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
}
+ e.rcvMu.Unlock()
+ // Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
Timestamp: p.timestamp,
}
- e.mu.RLock()
- receiveTOS := e.receiveTOS
- receiveTClass := e.receiveTClass
- receiveIPPacketInfo := e.receiveIPPacketInfo
- e.mu.RUnlock()
- if receiveTOS {
+ if e.ops.GetReceiveTOS() {
cm.HasTOS = true
cm.TOS = p.tos
}
- if receiveTClass {
+ if e.ops.GetReceiveTClass() {
cm.HasTClass = true
// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
cm.TClass = uint32(p.tos)
}
- if receiveIPPacketInfo {
+ if e.ops.GetReceivePacketInfo() {
cm.HasIPPacketInfo = true
cm.PacketInfo = p.packetInfo
}
- return p.data.ToView(), cm, nil
+ if e.ops.GetReceiveOriginalDstAddress() {
+ cm.HasOriginalDstAddress = true
+ cm.OriginalDstAddress = p.destinationAddress
+ }
+
+ // Read Result
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: cm,
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
+ }
+
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -328,7 +354,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
//
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
- switch e.state {
+ switch e.EndpointState() {
case StateInitial:
case StateConnected:
return false, nil
@@ -350,7 +376,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return true, nil
}
@@ -365,9 +391,9 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.ID.LocalAddress
- if isBroadcastOrMulticast(localAddr) {
+ if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
localAddr = ""
}
@@ -382,9 +408,9 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
if err != nil {
- return stack.Route{}, 0, err
+ return nil, 0, err
}
return r, nicID, nil
}
@@ -427,7 +453,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
to := opts.To
e.mu.RLock()
- defer e.mu.RUnlock()
+ lockReleased := false
+ defer func() {
+ if lockReleased {
+ return
+ }
+ e.mu.RUnlock()
+ }()
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
@@ -446,36 +478,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
}
- var route *stack.Route
- var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error)
- var dstPort uint16
- if to == nil {
- route = &e.route
- dstPort = e.dstPort
- resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) {
- // Promote lock to exclusive if using a shared route, given that it may
- // need to change in Route.Resolve() call below.
- e.mu.RUnlock()
- e.mu.Lock()
-
- // Recheck state after lock was re-acquired.
- if e.state != StateConnected {
- err = tcpip.ErrInvalidEndpointState
- }
- if err == nil && route.IsResolutionRequired() {
- ch, err = route.Resolve(waker)
- }
-
- e.mu.Unlock()
- e.mu.RLock()
-
- // Recheck state after lock was re-acquired.
- if e.state != StateConnected {
- err = tcpip.ErrInvalidEndpointState
- }
- return
- }
- } else {
+ route := e.route
+ dstPort := e.dstPort
+ if to != nil {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
@@ -503,17 +508,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
defer r.Release()
- route = &r
+ route = r
dstPort = dst.Port
- resolve = route.Resolve
}
- if !e.broadcast && route.IsOutboundBroadcast() {
+ if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
return 0, nil, tcpip.ErrBroadcastDisabled
}
if route.IsResolutionRequired() {
- if ch, err := resolve(nil); err != nil {
+ if ch, err := route.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
return 0, ch, tcpip.ErrNoLinkAddress
}
@@ -527,6 +531,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
+ so := e.SocketOptions()
+ if so.GetRecvError() {
+ so.QueueLocalErr(
+ tcpip.ErrMessageTooLong,
+ route.NetProto,
+ header.UDPMaximumPacketSize,
+ tcpip.FullAddress{
+ NIC: route.NICID(),
+ Addr: route.RemoteAddress,
+ Port: dstPort,
+ },
+ v,
+ )
+ }
return 0, nil, tcpip.ErrMessageTooLong
}
@@ -539,83 +557,41 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil {
+ localPort := e.ID.LocalPort
+ sendTOS := e.sendTOS
+ owner := e.owner
+ noChecksum := e.SocketOptions().GetNoChecksum()
+ lockReleased = true
+ e.mu.RUnlock()
+
+ // Do not hold lock when sending as loopback is synchronous and if the UDP
+ // datagram ends up generating an ICMP response then it can result in a
+ // deadlock where the ICMP response handling ends up acquiring this endpoint's
+ // mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a
+ // deadlock if another caller is trying to acquire e.mu in exclusive mode w/
+ // e.mu.Lock(). Since e.mu.Lock() prevents any new read locks to ensure the
+ // lock can be eventually acquired.
+ //
+ // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
+ // locking is prohibited.
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
+func (e *endpoint) OnReuseAddressSet(v bool) {
+ e.mu.Lock()
+ e.portFlags.MostRecent = v
+ e.mu.Unlock()
}
-// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
-func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.BroadcastOption:
- e.mu.Lock()
- e.broadcast = v
- e.mu.Unlock()
-
- case tcpip.MulticastLoopOption:
- e.mu.Lock()
- e.multicastLoop = v
- e.mu.Unlock()
-
- case tcpip.NoChecksumOption:
- e.mu.Lock()
- e.noChecksum = v
- e.mu.Unlock()
-
- case tcpip.ReceiveTOSOption:
- e.mu.Lock()
- e.receiveTOS = v
- e.mu.Unlock()
-
- case tcpip.ReceiveTClassOption:
- // We only support this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrNotSupported
- }
-
- e.mu.Lock()
- e.receiveTClass = v
- e.mu.Unlock()
-
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.Lock()
- e.receiveIPPacketInfo = v
- e.mu.Unlock()
-
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.portFlags.MostRecent = v
- e.mu.Unlock()
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.portFlags.LoadBalanced = v
- e.mu.Unlock()
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // We only allow this to be set when we're in the initial state.
- if e.state != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.v6only = v
- }
-
- return nil
+// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet.
+func (e *endpoint) OnReusePortSet(v bool) {
+ e.mu.Lock()
+ e.portFlags.LoadBalanced = v
+ e.mu.Unlock()
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
@@ -691,6 +667,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+func (e *endpoint) HasNIC(id int32) bool {
+ return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
@@ -737,14 +717,9 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
nicID := v.NIC
- // The interface address is considered not-set if it is empty or contains
- // all-zeros. The former represent the zero-value in golang, the latter the
- // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall.
- allZeros := header.IPv4Any
- if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros {
+ if v.InterfaceAddr.Unspecified() {
if nicID == 0 {
- r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
- if err == nil {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
nicID = r.NICID()
r.Release()
}
@@ -777,10 +752,9 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
}
nicID := v.NIC
- if v.InterfaceAddr == header.IPv4Any {
+ if v.InterfaceAddr.Unspecified() {
if nicID == 0 {
- r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
- if err == nil {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
nicID = r.NICID()
r.Release()
}
@@ -807,107 +781,12 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
delete(e.multicastMemberships, memToRemove)
- case *tcpip.BindToDeviceOption:
- id := tcpip.NICID(*v)
- if id != 0 && !e.stack.HasNIC(id) {
- return tcpip.ErrUnknownDevice
- }
- e.mu.Lock()
- e.bindToDevice = id
- e.mu.Unlock()
-
case *tcpip.SocketDetachFilterOption:
return nil
-
- case *tcpip.LingerOption:
- e.mu.Lock()
- e.linger = *v
- e.mu.Unlock()
}
return nil
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.BroadcastOption:
- e.mu.RLock()
- v := e.broadcast
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.KeepaliveEnabledOption:
- return false, nil
-
- case tcpip.MulticastLoopOption:
- e.mu.RLock()
- v := e.multicastLoop
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.NoChecksumOption:
- e.mu.RLock()
- v := e.noChecksum
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveTOSOption:
- e.mu.RLock()
- v := e.receiveTOS
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveTClassOption:
- // We only support this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrNotSupported
- }
-
- e.mu.RLock()
- v := e.receiveTClass
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReceiveIPPacketInfoOption:
- e.mu.RLock()
- v := e.receiveIPPacketInfo
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.ReuseAddressOption:
- e.mu.RLock()
- v := e.portFlags.MostRecent
- e.mu.RUnlock()
-
- return v, nil
-
- case tcpip.ReusePortOption:
- e.mu.RLock()
- v := e.portFlags.LoadBalanced
- e.mu.RUnlock()
-
- return v, nil
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return false, tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.RLock()
- v := e.v6only
- e.mu.RUnlock()
-
- return v, nil
-
- case tcpip.AcceptConnOption:
- return false, nil
-
- default:
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
@@ -977,16 +856,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
}
e.mu.Unlock()
- case *tcpip.BindToDeviceOption:
- e.mu.RLock()
- *o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.mu.RUnlock()
-
- case *tcpip.LingerOption:
- e.mu.RLock()
- *o = e.linger
- e.mu.RUnlock()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -1046,7 +915,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -1058,7 +927,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != StateConnected {
+ if e.EndpointState() != StateConnected {
return nil
}
var (
@@ -1081,7 +950,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
if err != nil {
return err
}
- e.state = StateBound
+ e.setEndpointState(StateBound)
boundPortFlags = e.boundPortFlags
} else {
if e.ID.LocalPort != 0 {
@@ -1089,14 +958,14 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
- e.state = StateInitial
+ e.setEndpointState(StateInitial)
}
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
- e.route = stack.Route{}
+ e.route = nil
e.dstPort = 0
return nil
@@ -1114,7 +983,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicID := addr.NIC
var localPort uint16
- switch e.state {
+ switch e.EndpointState() {
case StateInitial:
case StateBound, StateConnected:
localPort = e.ID.LocalPort
@@ -1140,7 +1009,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: e.ID.LocalAddress,
@@ -1149,7 +1017,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
RemoteAddress: r.RemoteAddress,
}
- if e.state == StateInitial {
+ if e.EndpointState() == StateInitial {
id.LocalAddress = r.LocalAddress
}
@@ -1157,7 +1025,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// packets on a different network protocol, so we register both even if
// v6only is set to false and this is an ipv6 endpoint.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only {
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv4ProtocolNumber,
header.IPv6ProtocolNumber,
@@ -1168,6 +1036,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
@@ -1178,12 +1047,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.ID = id
e.boundBindToDevice = btd
- e.route = r.Clone()
+ e.route = r
e.dstPort = addr.Port
e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
- e.state = StateConnected
+ e.setEndpointState(StateConnected)
e.rcvMu.Lock()
e.rcvReady = true
@@ -1205,7 +1074,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
- if e.state != StateBound && e.state != StateConnected {
+ if state := e.EndpointState(); state != StateBound && state != StateConnected {
return tcpip.ErrNotConnected
}
@@ -1236,27 +1105,28 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp
}
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
if err != nil {
- return id, e.bindToDevice, err
+ return id, bindToDevice, err
}
id.LocalPort = port
}
e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{})
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
- return id, e.bindToDevice, err
+ return id, bindToDevice, err
}
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -1269,7 +1139,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// wildcard (empty) address, and this is an IPv6 endpoint with v6only
// set to false.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv6ProtocolNumber,
header.IPv4ProtocolNumber,
@@ -1277,7 +1147,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
}
nicID := addr.NIC
- if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
+ if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
// A local unicast address was specified, verify that it's valid.
nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nicID == 0 {
@@ -1300,7 +1170,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
- e.state = StateBound
+ e.setEndpointState(StateBound)
e.rcvMu.Lock()
e.rcvReady = true
@@ -1332,7 +1202,7 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
addr := e.ID.LocalAddress
- if e.state == StateConnected {
+ if e.EndpointState() == StateConnected {
addr = e.route.LocalAddress
}
@@ -1348,7 +1218,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != StateConnected {
+ if e.EndpointState() != StateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -1447,6 +1317,11 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
Addr: id.RemoteAddress,
Port: header.UDP(hdr).SourcePort(),
},
+ destinationAddress: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: header.UDP(hdr).DestinationPort(),
+ },
}
packet.data = pkt.Data
e.rcvList.PushBack(packet)
@@ -1477,28 +1352,71 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
}
+func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+ // Update last error first.
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ // Update the error queue if IP_RECVERR is enabled.
+ if e.SocketOptions().GetRecvError() {
+ // Linux passes the payload without the UDP header.
+ var payload []byte
+ udp := header.UDP(pkt.Data.ToView())
+ if len(udp) >= header.UDPMinimumSize {
+ payload = udp.Payload()
+ }
+
+ e.SocketOptions().QueueErr(&tcpip.SockError{
+ Err: err,
+ ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber),
+ ErrType: errType,
+ ErrCode: errCode,
+ ErrInfo: extra,
+ Payload: payload,
+ Dst: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ },
+ NetProto: pkt.NetworkProtocolNumber,
+ })
+ }
+
+ // Notify of the error.
+ e.waiterQueue.Notify(waiter.EventErr)
+}
+
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
- e.mu.RLock()
- if e.state == StateConnected {
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrConnectionRefused
- e.lastErrorMu.Unlock()
- e.mu.RUnlock()
-
- e.waiterQueue.Notify(waiter.EventErr)
+ if e.EndpointState() == StateConnected {
+ var errType byte
+ var errCode byte
+ switch pkt.NetworkProtocolNumber {
+ case header.IPv4ProtocolNumber:
+ errType = byte(header.ICMPv4DstUnreachable)
+ errCode = byte(header.ICMPv4PortUnreachable)
+ case header.IPv6ProtocolNumber:
+ errType = byte(header.ICMPv6DstUnreachable)
+ errCode = byte(header.ICMPv6PortUnreachable)
+ default:
+ panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber))
+ }
+ e.onICMPError(tcpip.ErrConnectionRefused, errType, errCode, extra, pkt)
return
}
- e.mu.RUnlock()
}
}
// State implements tcpip.Endpoint.State.
func (e *endpoint) State() uint32 {
- e.mu.Lock()
- defer e.mu.Unlock()
- return uint32(e.state)
+ return uint32(e.EndpointState())
}
// Info returns a copy of the endpoint info.
@@ -1518,10 +1436,16 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements tcpip.Endpoint.Wait.
func (*endpoint) Wait() {}
-func isBroadcastOrMulticast(a tcpip.Address) bool {
- return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
+func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
}
+// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
+
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
+func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 858c99a45..13b72dc88 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -98,7 +98,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
}
- if e.state != StateBound && e.state != StateConnected {
+ state := e.EndpointState()
+ if state != StateBound && state != StateConnected {
return
}
@@ -113,12 +114,12 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
var err *tcpip.Error
- if e.state == StateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop)
+ if state == StateConnected {
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
if err != nil {
panic(err)
}
- } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound
+ } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound
// A local unicast address is specified, verify that it's valid.
if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 14e4648cd..d7fc21f11 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -78,7 +78,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
route.ResolveWith(r.pkt.SourceLinkAddress())
ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
- if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
route.Release()
return nil, err
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index c09c7aa86..c8da173f1 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -18,10 +18,12 @@ import (
"bytes"
"context"
"fmt"
+ "io/ioutil"
"math/rand"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -32,6 +34,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -54,6 +57,7 @@ const (
stackPort = 1234
testAddr = "\x0a\x00\x00\x02"
testPort = 4096
+ invalidPort = 8192
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
broadcastAddr = header.IPv4Broadcast
@@ -295,7 +299,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
return newDualTestContextWithOptions(t, mtu, stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
+ HandleLocal: true,
})
}
@@ -360,13 +365,9 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.createEndpoint(flow.sockProto())
if flow.isV6Only() {
- if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- c.t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetV6Only(true)
} else if flow.isBroadcast() {
- if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
- c.t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetBroadcast(true)
}
}
@@ -453,12 +454,12 @@ func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
+ TrafficClass: testTOS,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
@@ -555,7 +556,7 @@ func TestBindToDeviceOption(t *testing.T) {
name string
setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
- getBindToDevice tcpip.BindToDeviceOption
+ getBindToDevice int32
}{
{"GetDefaultValue", nil, nil, 0},
{"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
@@ -565,15 +566,13 @@ func TestBindToDeviceOption(t *testing.T) {
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
- bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ bindToDevice := int32(*testAction.setBindToDevice)
+ if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption(88888)
- if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err)
- } else if bindToDevice != testAction.getBindToDevice {
+ bindToDevice := ep.SocketOptions().GetBindToDevice()
+ if bindToDevice != testAction.getBindToDevice {
t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice)
}
})
@@ -598,13 +597,13 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- var addr tcpip.FullAddress
- v, cm, err := c.ep.Read(&addr)
+ var buf bytes.Buffer
+ res, err := c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true})
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, cm, err = c.ep.Read(&addr)
+ res, err = c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true})
case <-time.After(300 * time.Millisecond):
if packetShouldBeDropped {
@@ -624,23 +623,32 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
}
if packetShouldBeDropped {
- c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr)
+ c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr)
}
- // Check the peer address.
+ // Check the read result.
h := flow.header4Tuple(incoming)
- if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("got address = %s, want = %s", addr.Addr, h.srcAddr.Addr)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages", // ControlMessages will be checked later.
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff)
}
// Check the payload.
+ v := buf.Bytes()
if !bytes.Equal(payload, v) {
c.t.Fatalf("got payload = %x, want = %x", v, payload)
}
// Run any checkers against the ControlMessages.
for _, f := range checkers {
- f(c.t, cm)
+ f(c.t, res.ControlMessages)
}
c.checkEndpointReadStats(1, epstats, err)
@@ -831,8 +839,8 @@ func TestV4ReadSelfSource(t *testing.T) {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
}
- if _, _, err := c.ep.Read(nil); err != tt.wantErr {
- t.Errorf("got c.ep.Read(nil) = %s, want = %s", err, tt.wantErr)
+ if _, err := c.ep.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err != tt.wantErr {
+ t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr)
}
})
}
@@ -974,7 +982,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
// provided.
func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
- return testWriteInternal(c, flow, true, checkers...)
+ return testWriteAndVerifyInternal(c, flow, true, checkers...)
}
// testWriteWithoutDestination sends a packet of the given test flow from the
@@ -983,10 +991,10 @@ func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker
// checker functions provided.
func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
c.t.Helper()
- return testWriteInternal(c, flow, false, checkers...)
+ return testWriteAndVerifyInternal(c, flow, false, checkers...)
}
-func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View {
c.t.Helper()
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
@@ -1008,6 +1016,12 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
c.checkEndpointWriteStats(1, epstats, err)
+ return payload
+}
+
+func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ payload := testWriteNoVerify(c, flow, setDest)
// Received the packet and check the payload.
b := c.getPacketAndVerify(flow, checkers...)
var udp header.UDP
@@ -1152,6 +1166,39 @@ func TestV4WriteOnConnected(t *testing.T) {
testWriteWithoutDestination(c, unicastV4)
}
+func TestWriteOnConnectedInvalidPort(t *testing.T) {
+ protocols := map[string]tcpip.NetworkProtocolNumber{
+ "ipv4": ipv4.ProtocolNumber,
+ "ipv6": ipv6.ProtocolNumber,
+ }
+ for name, pn := range protocols {
+ t.Run(name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(pn)
+ if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil {
+ c.t.Fatalf("Connect failed: %s", err)
+ }
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
+ }
+ payload := buffer.View(newPayload())
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ if err != nil {
+ c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err)
+ }
+ if got, want := n, int64(len(payload)); got != want {
+ c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
+ }
+
+ if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused {
+ c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
+ }
+ })
+ }
+}
+
// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
// that is bound to a V4 multicast address.
func TestWriteOnBoundToV4Multicast(t *testing.T) {
@@ -1374,9 +1421,7 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
- if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil {
- t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err)
- }
+ c.ep.SocketOptions().SetReceivePacketInfo(true)
testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
NIC: 1,
@@ -1391,6 +1436,93 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
+func TestReadRecvOriginalDstAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ expectedOriginalDstAddr tcpip.FullAddress
+ }{
+ {
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort},
+ },
+ {
+ name: "IPv4 multicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: multicastV4,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort},
+ },
+ {
+ name: "IPv4 broadcast",
+ proto: header.IPv4ProtocolNumber,
+ flow: broadcast,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort},
+ },
+ {
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort},
+ },
+ {
+ name: "IPv6 multicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: multicastV6,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(test.proto)
+
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%#v): %s", bindAddr, err)
+ }
+
+ if test.flow.isMulticast() {
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
+ }
+ }
+
+ c.ep.SocketOptions().SetReceiveOriginalDstAddress(true)
+
+ testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr))
+
+ if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
+ t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
+ }
+ })
+ }
+}
+
func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1414,16 +1546,12 @@ func TestNoChecksum(t *testing.T) {
c.createEndpointForFlow(flow)
// Disable the checksum generation.
- if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil {
- t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetNoChecksum(true)
// This option is effective on IPv4 only.
testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
// Enable the checksum generation.
- if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil {
- t.Fatalf("SetSockOptBool failed: %s", err)
- }
+ c.ep.SocketOptions().SetNoChecksum(false)
testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
})
}
@@ -1432,29 +1560,17 @@ func TestNoChecksum(t *testing.T) {
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
- stack.NetworkLinkEndpoint
+ stack.NetworkInterface
}
func (*testInterface) ID() tcpip.NICID {
return 0
}
-func (*testInterface) IsLoopback() bool {
- return false
-}
-
-func (*testInterface) Name() string {
- return ""
-}
-
func (*testInterface) Enabled() bool {
return true
}
-func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1589,13 +1705,15 @@ func TestSetTClass(t *testing.T) {
}
func TestReceiveTosTClass(t *testing.T) {
+ const RcvTOSOpt = "ReceiveTosOption"
+ const RcvTClassOpt = "ReceiveTClassOption"
+
testCases := []struct {
- name string
- getReceiveOption tcpip.SockOptBool
- tests []testFlow
+ name string
+ tests []testFlow
}{
- {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
- {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ {RcvTOSOpt, []testFlow{unicastV4, broadcast}},
+ {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
}
for _, testCase := range testCases {
for _, flow := range testCase.tests {
@@ -1604,29 +1722,32 @@ func TestReceiveTosTClass(t *testing.T) {
defer c.cleanup()
c.createEndpointForFlow(flow)
- option := testCase.getReceiveOption
name := testCase.name
- // Verify that setting and reading the option works.
- v, err := c.ep.GetSockOptBool(option)
- if err != nil {
- c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
+ var optionGetter func() bool
+ var optionSetter func(bool)
+ switch name {
+ case RcvTOSOpt:
+ optionGetter = c.ep.SocketOptions().GetReceiveTOS
+ optionSetter = c.ep.SocketOptions().SetReceiveTOS
+ case RcvTClassOpt:
+ optionGetter = c.ep.SocketOptions().GetReceiveTClass
+ optionSetter = c.ep.SocketOptions().SetReceiveTClass
+ default:
+ t.Fatalf("unkown test variant: %s", name)
}
+
+ // Verify that setting and reading the option works.
+ v := optionGetter()
// Test for expected default value.
if v != false {
c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
}
want := true
- if err := c.ep.SetSockOptBool(option, want); err != nil {
- c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
- }
-
- got, err := c.ep.GetSockOptBool(option)
- if err != nil {
- c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
- }
+ optionSetter(want)
+ got := optionGetter()
if got != want {
c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
}
@@ -1636,10 +1757,10 @@ func TestReceiveTosTClass(t *testing.T) {
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %s", err)
}
- switch option {
- case tcpip.ReceiveTClassOption:
+ switch name {
+ case RcvTClassOpt:
testRead(c, flow, checker.ReceiveTClass(testTOS))
- case tcpip.ReceiveTOSOption:
+ case RcvTOSOpt:
testRead(c, flow, checker.ReceiveTOS(testTOS))
default:
t.Fatalf("unknown test variant: %s", name)
@@ -1953,12 +2074,12 @@ func TestShortHeader(t *testing.T) {
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(udpSize),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
+ TrafficClass: testTOS,
+ PayloadLength: uint16(udpSize),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
@@ -2393,17 +2514,13 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
}
- if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
- t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err)
- }
+ ep.SocketOptions().SetBroadcast(true)
if n, _, err := ep.Write(data, opts); err != nil {
t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err)
}
- if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil {
- t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err)
- }
+ ep.SocketOptions().SetBroadcast(false)
if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)