summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go6
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go6
-rw-r--r--pkg/tcpip/buffer/BUILD4
-rw-r--r--pkg/tcpip/buffer/view.go28
-rw-r--r--pkg/tcpip/checker/BUILD1
-rw-r--r--pkg/tcpip/checker/checker.go292
-rw-r--r--pkg/tcpip/faketime/BUILD24
-rw-r--r--pkg/tcpip/faketime/faketime.go236
-rw-r--r--pkg/tcpip/faketime/faketime_test.go95
-rw-r--r--pkg/tcpip/header/BUILD4
-rw-r--r--pkg/tcpip/header/arp.go77
-rw-r--r--pkg/tcpip/header/eth.go4
-rw-r--r--pkg/tcpip/header/icmpv4.go56
-rw-r--r--pkg/tcpip/header/icmpv6.go79
-rw-r--r--pkg/tcpip/header/ipv4.go98
-rw-r--r--pkg/tcpip/header/ipv6.go12
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go113
-rw-r--r--pkg/tcpip/header/ipversion_test.go2
-rw-r--r--pkg/tcpip/header/parse/BUILD15
-rw-r--r--pkg/tcpip/header/parse/parse.go168
-rw-r--r--pkg/tcpip/header/udp.go5
-rw-r--r--pkg/tcpip/link/channel/BUILD1
-rw-r--r--pkg/tcpip/link/channel/channel.go14
-rw-r--r--pkg/tcpip/link/fdbased/BUILD2
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go113
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go211
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go16
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go45
-rw-r--r--pkg/tcpip/link/loopback/loopback.go29
-rw-r--r--pkg/tcpip/link/muxed/BUILD1
-rw-r--r--pkg/tcpip/link/muxed/injectable.go10
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go25
-rw-r--r--pkg/tcpip/link/nested/BUILD1
-rw-r--r--pkg/tcpip/link/nested/nested.go21
-rw-r--r--pkg/tcpip/link/nested/nested_test.go8
-rw-r--r--pkg/tcpip/link/packetsocket/BUILD14
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go50
-rw-r--r--pkg/tcpip/link/qdisc/fifo/BUILD1
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go18
-rw-r--r--pkg/tcpip/link/rawfile/BUILD13
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go2
-rw-r--r--pkg/tcpip/link/rawfile/errors.go8
-rw-r--r--pkg/tcpip/link/rawfile/errors_test.go53
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go32
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go51
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go139
-rw-r--r--pkg/tcpip/link/sharedmem/tx.go14
-rw-r--r--pkg/tcpip/link/sniffer/BUILD1
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go80
-rw-r--r--pkg/tcpip/link/tun/BUILD15
-rw-r--r--pkg/tcpip/link/tun/device.go86
-rw-r--r--pkg/tcpip/link/waitable/BUILD2
-rw-r--r--pkg/tcpip/link/waitable/waitable.go20
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go27
-rw-r--r--pkg/tcpip/network/BUILD3
-rw-r--r--pkg/tcpip/network/arp/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp.go188
-rw-r--r--pkg/tcpip/network/arp/arp_test.go373
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD5
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go155
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go294
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go37
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go4
-rw-r--r--pkg/tcpip/network/ip_test.go685
-rw-r--r--pkg/tcpip/network/ipv4/BUILD7
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go294
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go640
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go1303
-rw-r--r--pkg/tcpip/network/ipv6/BUILD8
-rw-r--r--pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go (renamed from pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go)2
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go579
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go597
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go1146
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go1095
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go (renamed from pkg/tcpip/stack/ndp.go)742
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go976
-rw-r--r--pkg/tcpip/network/testutil/BUILD20
-rw-r--r--pkg/tcpip/network/testutil/testutil.go144
-rw-r--r--pkg/tcpip/ports/ports.go19
-rw-r--r--pkg/tcpip/ports/ports_test.go2
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go6
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go6
-rw-r--r--pkg/tcpip/stack/BUILD48
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go753
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go77
-rw-r--r--pkg/tcpip/stack/conntrack.go462
-rw-r--r--pkg/tcpip/stack/forwarder_test.go764
-rw-r--r--pkg/tcpip/stack/headertype_string.go39
-rw-r--r--pkg/tcpip/stack/iptables.go381
-rw-r--r--pkg/tcpip/stack/iptables_state.go40
-rw-r--r--pkg/tcpip/stack/iptables_targets.go154
-rw-r--r--pkg/tcpip/stack/iptables_types.go136
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go2
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go79
-rw-r--r--pkg/tcpip/stack/ndp_test.go1667
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go333
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go1727
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go490
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go2869
-rw-r--r--pkg/tcpip/stack/neighborstate_string.go44
-rw-r--r--pkg/tcpip/stack/nic.go1511
-rw-r--r--pkg/tcpip/stack/nic_test.go164
-rw-r--r--pkg/tcpip/stack/nud.go466
-rw-r--r--pkg/tcpip/stack/nud_test.go807
-rw-r--r--pkg/tcpip/stack/packet_buffer.go320
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go397
-rw-r--r--pkg/tcpip/stack/registration.go382
-rw-r--r--pkg/tcpip/stack/route.go177
-rw-r--r--pkg/tcpip/stack/stack.go608
-rw-r--r--pkg/tcpip/stack/stack_test.go1147
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go24
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go26
-rw-r--r--pkg/tcpip/stack/transport_test.go198
-rw-r--r--pkg/tcpip/tcpip.go464
-rw-r--r--pkg/tcpip/tests/integration/BUILD26
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go314
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go556
-rw-r--r--pkg/tcpip/time_unsafe.go32
-rw-r--r--pkg/tcpip/timer.go168
-rw-r--r--pkg/tcpip/timer_test.go91
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go78
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go33
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go153
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go19
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go107
-rw-r--r--pkg/tcpip/transport/tcp/BUILD6
-rw-r--r--pkg/tcpip/transport/tcp/accept.go15
-rw-r--r--pkg/tcpip/transport/tcp/connect.go86
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go57
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go536
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go18
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go236
-rw-r--r--pkg/tcpip/transport/tcp/rack.go94
-rw-r--r--pkg/tcpip/transport/tcp/rack_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go101
-rw-r--r--pkg/tcpip/transport/tcp/segment.go53
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go52
-rw-r--r--pkg/tcpip/transport/tcp/segment_unsafe.go23
-rw-r--r--pkg/tcpip/transport/tcp/snd.go37
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go74
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go15
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go1882
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go29
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go197
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go5
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go177
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go2
-rw-r--r--pkg/tcpip/transport/udp/protocol.go155
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go702
150 files changed, 26496 insertions, 8199 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index d82ed5205..4f551cd92 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -245,7 +245,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
// Accept implements net.Conn.Accept.
func (l *TCPListener) Accept() (net.Conn, error) {
- n, wq, err := l.ep.Accept()
+ n, wq, err := l.ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
@@ -254,7 +254,7 @@ func (l *TCPListener) Accept() (net.Conn, error) {
defer l.wq.EventUnregister(&waitEntry)
for {
- n, wq, err = l.ep.Accept()
+ n, wq, err = l.ep.Accept(nil)
if err != tcpip.ErrWouldBlock {
break
@@ -541,7 +541,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
case <-notifyCh:
}
- err = ep.GetSockOpt(tcpip.ErrorOption{})
+ err = ep.LastError()
}
if err != nil {
ep.Close()
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 3c552988a..12b061def 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -61,8 +61,8 @@ func TestTimeouts(t *testing.T) {
func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
// Create the stack and add a NIC.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
})
if err := s.CreateNIC(NICID, loopback.New()); err != nil {
@@ -104,7 +104,7 @@ func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Er
err = ep.Connect(addr)
if err == tcpip.ErrConnectStarted {
<-ch
- err = ep.GetSockOpt(tcpip.ErrorOption{})
+ err = ep.LastError()
}
if err != nil {
return nil, err
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index 563bc78ea..c326fab54 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -14,6 +14,8 @@ go_library(
go_test(
name = "buffer_test",
size = "small",
- srcs = ["view_test.go"],
+ srcs = [
+ "view_test.go",
+ ],
library = ":buffer",
)
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 9a3c5d6c3..8db70a700 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -65,6 +65,16 @@ func (v View) ToVectorisedView() VectorisedView {
return NewVectorisedView(len(v), []View{v})
}
+// IsEmpty returns whether v is of length zero.
+func (v View) IsEmpty() bool {
+ return len(v) == 0
+}
+
+// Size returns the length of v.
+func (v View) Size() int {
+ return len(v)
+}
+
// VectorisedView is a vectorised version of View using non contiguous memory.
// It supports all the convenience methods supported by View.
//
@@ -74,8 +84,8 @@ type VectorisedView struct {
size int
}
-// NewVectorisedView creates a new vectorised view from an already-allocated slice
-// of View and sets its size.
+// NewVectorisedView creates a new vectorised view from an already-allocated
+// slice of View and sets its size.
func NewVectorisedView(size int, views []View) VectorisedView {
return VectorisedView{views: views, size: size}
}
@@ -160,8 +170,9 @@ func (vv *VectorisedView) CapLength(length int) {
}
// Clone returns a clone of this VectorisedView.
-// If the buffer argument is large enough to contain all the Views of this VectorisedView,
-// the method will avoid allocations and use the buffer to store the Views of the clone.
+// If the buffer argument is large enough to contain all the Views of this
+// VectorisedView, the method will avoid allocations and use the buffer to
+// store the Views of the clone.
func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
@@ -199,7 +210,8 @@ func (vv *VectorisedView) PullUp(count int) (View, bool) {
return newFirst, true
}
-// Size returns the size in bytes of the entire content stored in the vectorised view.
+// Size returns the size in bytes of the entire content stored in the
+// vectorised view.
func (vv *VectorisedView) Size() int {
return vv.size
}
@@ -212,6 +224,12 @@ func (vv *VectorisedView) ToView() View {
if len(vv.views) == 1 {
return vv.views[0]
}
+ return vv.ToOwnedView()
+}
+
+// ToOwnedView returns a single view containing the content of the vectorised
+// view that vv does not own.
+func (vv *VectorisedView) ToOwnedView() View {
u := make([]byte, 0, vv.size)
for _, v := range vv.views {
u = append(u, v...)
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
index ed434807f..c984470e6 100644
--- a/pkg/tcpip/checker/BUILD
+++ b/pkg/tcpip/checker/BUILD
@@ -12,5 +12,6 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index ee264b726..d4d785cca 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -21,6 +21,7 @@ import (
"reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -117,18 +118,82 @@ func TTL(ttl uint8) NetworkChecker {
v = ip.HopLimit()
}
if v != ttl {
- t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
+ t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
+ }
+ }
+}
+
+// IPFullLength creates a checker for the full IP packet length. The
+// expected size is checked against both the Total Length in the
+// header and the number of bytes received.
+func IPFullLength(packetLength uint16) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ var v uint16
+ var l uint16
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ v = ip.TotalLength()
+ l = uint16(len(ip))
+ case header.IPv6:
+ v = ip.PayloadLength() + header.IPv6FixedHeaderSize
+ l = uint16(len(ip))
+ default:
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip)
+ }
+ if l != packetLength {
+ t.Errorf("bad packet length, got = %d, want = %d", l, packetLength)
+ }
+ if v != packetLength {
+ t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength)
+ }
+ }
+}
+
+// IPv4HeaderLength creates a checker that checks the IPv4 Header length.
+func IPv4HeaderLength(headerLength int) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ if hl := ip.HeaderLength(); hl != uint8(headerLength) {
+ t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength)
+ }
+ default:
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip)
}
}
}
// PayloadLen creates a checker that checks the payload length.
-func PayloadLen(plen int) NetworkChecker {
+func PayloadLen(payloadLength int) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- if l := len(h[0].Payload()); l != plen {
- t.Errorf("Bad payload length, got %v, want %v", l, plen)
+ if l := len(h[0].Payload()); l != payloadLength {
+ t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength)
+ }
+ }
+}
+
+// IPv4Options returns a checker that checks the options in an IPv4 packet.
+func IPv4Options(want []byte) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ ip, ok := h[0].(header.IPv4)
+ if !ok {
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
+ }
+ options := ip.Options()
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(want) == 0 && len(options) == 0 {
+ return
+ }
+ if diff := cmp.Diff(want, options); diff != "" {
+ t.Errorf("options mismatch (-want +got):\n%s", diff)
}
}
}
@@ -138,11 +203,11 @@ func FragmentOffset(offset uint16) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- // We only do this of IPv4 for now.
+ // We only do this for IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.FragmentOffset(); v != offset {
- t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
+ t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset)
}
}
}
@@ -153,11 +218,11 @@ func FragmentFlags(flags uint8) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- // We only do this of IPv4 for now.
+ // We only do this for IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.Flags(); v != flags {
- t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
+ t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags)
}
}
}
@@ -169,10 +234,9 @@ func ReceiveTClass(want uint32) ControlMessagesChecker {
return func(t *testing.T, cm tcpip.ControlMessages) {
t.Helper()
if !cm.HasTClass {
- t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want)
- }
- if got := cm.TClass; got != want {
- t.Fatalf("got cm.TClass = %d, want %d", got, want)
+ t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass)
+ } else if got := cm.TClass; got != want {
+ t.Errorf("got cm.TClass = %d, want %d", got, want)
}
}
}
@@ -182,10 +246,22 @@ func ReceiveTOS(want uint8) ControlMessagesChecker {
return func(t *testing.T, cm tcpip.ControlMessages) {
t.Helper()
if !cm.HasTOS {
- t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want)
+ t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS)
+ } else if got := cm.TOS; got != want {
+ t.Errorf("got cm.TOS = %d, want %d", got, want)
}
- if got := cm.TOS; got != want {
- t.Fatalf("got cm.TOS = %d, want %d", got, want)
+ }
+}
+
+// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in
+// ControlMessages.
+func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasIPPacketInfo {
+ t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo)
+ } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" {
+ t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff)
}
}
}
@@ -196,7 +272,7 @@ func TOS(tos uint8, label uint32) NetworkChecker {
t.Helper()
if v, l := h[0].TOS(); v != tos || l != label {
- t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
+ t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label)
}
}
}
@@ -222,7 +298,7 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
t.Helper()
if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
}
ipv6Frag := header.IPv6Fragment(h[0].Payload())
@@ -249,7 +325,7 @@ func TCP(checkers ...TransportChecker) NetworkChecker {
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
}
// Verify the checksum.
@@ -285,7 +361,7 @@ func UDP(checkers ...TransportChecker) NetworkChecker {
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
}
udp := header.UDP(last.Payload())
@@ -304,7 +380,7 @@ func SrcPort(port uint16) TransportChecker {
t.Helper()
if p := h.SourcePort(); p != port {
- t.Errorf("Bad source port, got %v, want %v", p, port)
+ t.Errorf("Bad source port, got = %d, want = %d", p, port)
}
}
}
@@ -315,7 +391,7 @@ func DstPort(port uint16) TransportChecker {
t.Helper()
if p := h.DestinationPort(); p != port {
- t.Errorf("Bad destination port, got %v, want %v", p, port)
+ t.Errorf("Bad destination port, got = %d, want = %d", p, port)
}
}
}
@@ -327,7 +403,7 @@ func NoChecksum(noChecksum bool) TransportChecker {
udp, ok := h.(header.UDP)
if !ok {
- return
+ t.Fatalf("UDP header not found in h: %T", h)
}
if b := udp.Checksum() == 0; b != noChecksum {
@@ -336,50 +412,84 @@ func NoChecksum(noChecksum bool) TransportChecker {
}
}
-// SeqNum creates a checker that checks the sequence number.
-func SeqNum(seq uint32) TransportChecker {
+// TCPSeqNum creates a checker that checks the sequence number.
+func TCPSeqNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in h: %T", h)
}
if s := tcp.SequenceNumber(); s != seq {
- t.Errorf("Bad sequence number, got %v, want %v", s, seq)
+ t.Errorf("Bad sequence number, got = %d, want = %d", s, seq)
}
}
}
-// AckNum creates a checker that checks the ack number.
-func AckNum(seq uint32) TransportChecker {
+// TCPAckNum creates a checker that checks the ack number.
+func TCPAckNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in h: %T", h)
}
if s := tcp.AckNumber(); s != seq {
- t.Errorf("Bad ack number, got %v, want %v", s, seq)
+ t.Errorf("Bad ack number, got = %d, want = %d", s, seq)
}
}
}
-// Window creates a checker that checks the tcp window.
-func Window(window uint16) TransportChecker {
+// TCPWindow creates a checker that checks the tcp window.
+func TCPWindow(window uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in hdr : %T", h)
}
if w := tcp.WindowSize(); w != window {
- t.Errorf("Bad window, got 0x%x, want 0x%x", w, window)
+ t.Errorf("Bad window, got %d, want %d", w, window)
+ }
+ }
+}
+
+// TCPWindowGreaterThanEq creates a checker that checks that the TCP window
+// is greater than or equal to the provided value.
+func TCPWindowGreaterThanEq(window uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ t.Fatalf("TCP header not found in h: %T", h)
+ }
+
+ if w := tcp.WindowSize(); w < window {
+ t.Errorf("Bad window, got %d, want > %d", w, window)
+ }
+ }
+}
+
+// TCPWindowLessThanEq creates a checker that checks that the tcp window
+// is less than or equal to the provided value.
+func TCPWindowLessThanEq(window uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ t.Fatalf("TCP header not found in h: %T", h)
+ }
+
+ if w := tcp.WindowSize(); w > window {
+ t.Errorf("Bad window, got %d, want < %d", w, window)
}
}
}
@@ -391,7 +501,7 @@ func TCPFlags(flags uint8) TransportChecker {
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in h: %T", h)
}
if f := tcp.Flags(); f != flags {
@@ -408,7 +518,7 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker {
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in h: %T", h)
}
if f := tcp.Flags(); (f & mask) != (flags & mask) {
@@ -446,7 +556,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
case header.TCPOptionMSS:
v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
if wantOpts.MSS != v {
- t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
+ t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS)
}
foundMSS = true
i += 4
@@ -456,7 +566,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
}
v := int(opts[i+2])
if v != wantOpts.WS {
- t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
+ t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS)
}
foundWS = true
i += 3
@@ -505,7 +615,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
t.Error("TS option specified but the timestamp value is zero")
}
if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
- t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
+ t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr)
}
if wantOpts.SACKPermitted && !foundSACKPermitted {
t.Errorf("SACKPermitted option not found. Options: %x", opts)
@@ -543,7 +653,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
}
if opts[i+1] != 10 {
- t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
+ t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1])
}
tsVal = binary.BigEndian.Uint32(opts[i+2:])
tsEcr = binary.BigEndian.Uint32(opts[i+6:])
@@ -563,19 +673,19 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
}
if wantTS != foundTS {
- t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
+ t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS)
}
if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
- t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
+ t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal)
}
if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
- t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
+ t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr)
}
}
}
-// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not
-// contain any SACK blocks in the TCP options.
+// TCPNoSACKBlockChecker creates a checker that verifies that the segment does
+// not contain any SACK blocks in the TCP options.
func TCPNoSACKBlockChecker() TransportChecker {
return TCPSACKBlockChecker(nil)
}
@@ -633,7 +743,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
}
if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
- t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
+ t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks)
}
}
}
@@ -649,8 +759,8 @@ func Payload(want []byte) TransportChecker {
}
}
-// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and
-// potentially additional ICMPv4 header fields.
+// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4
+// and potentially additional ICMPv4 header fields.
func ICMPv4(checkers ...TransportChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
@@ -678,25 +788,91 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
icmpv4, ok := h.(header.ICMPv4)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
if got := icmpv4.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
}
}
}
// ICMPv4Code creates a checker that checks the ICMPv4 Code field.
-func ICMPv4Code(want byte) TransportChecker {
+func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
icmpv4, ok := h.(header.ICMPv4)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
if got := icmpv4.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident.
+func ICMPv4Ident(want uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Ident(); got != want {
+ t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence.
+func ICMPv4Seq(want uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Sequence(); got != want {
+ t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
+// This assumes that the payload exactly makes up the rest of the slice.
+func ICMPv4Checksum() TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ heldChecksum := icmpv4.Checksum()
+ icmpv4.SetChecksum(0)
+ newChecksum := ^header.Checksum(icmpv4, 0)
+ icmpv4.SetChecksum(heldChecksum)
+ if heldChecksum != newChecksum {
+ t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum)
+ }
+ }
+}
+
+// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet.
+func ICMPv4Payload(want []byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ payload := icmpv4.Payload()
+ if diff := cmp.Diff(want, payload); diff != "" {
+ t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
}
}
}
@@ -736,25 +912,25 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
icmpv6, ok := h.(header.ICMPv6)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
if got := icmpv6.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
}
}
}
// ICMPv6Code creates a checker that checks the ICMPv6 Code field.
-func ICMPv6Code(want byte) TransportChecker {
+func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
icmpv6, ok := h.(header.ICMPv6)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
if got := icmpv6.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
}
}
}
diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD
new file mode 100644
index 000000000..114d43df3
--- /dev/null
+++ b/pkg/tcpip/faketime/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "faketime",
+ srcs = ["faketime.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "@com_github_dpjacques_clockwork//:go_default_library",
+ ],
+)
+
+go_test(
+ name = "faketime_test",
+ size = "small",
+ srcs = [
+ "faketime_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip/faketime",
+ ],
+)
diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go
new file mode 100644
index 000000000..f7a4fbde1
--- /dev/null
+++ b/pkg/tcpip/faketime/faketime.go
@@ -0,0 +1,236 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package faketime provides a fake clock that implements tcpip.Clock interface.
+package faketime
+
+import (
+ "container/heap"
+ "sync"
+ "time"
+
+ "github.com/dpjacques/clockwork"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// NullClock implements a clock that never advances.
+type NullClock struct{}
+
+var _ tcpip.Clock = (*NullClock)(nil)
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (*NullClock) NowNanoseconds() int64 {
+ return 0
+}
+
+// NowMonotonic implements tcpip.Clock.NowMonotonic.
+func (*NullClock) NowMonotonic() int64 {
+ return 0
+}
+
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer {
+ return nil
+}
+
+// ManualClock implements tcpip.Clock and only advances manually with Advance
+// method.
+type ManualClock struct {
+ clock clockwork.FakeClock
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ // times is min-heap of times. A heap is used for quick retrieval of the next
+ // upcoming time of scheduled work.
+ times *timeHeap
+
+ // waitGroups stores one WaitGroup for all work scheduled to execute at the
+ // same time via AfterFunc. This allows parallel execution of all functions
+ // passed to AfterFunc scheduled for the same time.
+ waitGroups map[time.Time]*sync.WaitGroup
+}
+
+// NewManualClock creates a new ManualClock instance.
+func NewManualClock() *ManualClock {
+ return &ManualClock{
+ clock: clockwork.NewFakeClock(),
+ times: &timeHeap{},
+ waitGroups: make(map[time.Time]*sync.WaitGroup),
+ }
+}
+
+var _ tcpip.Clock = (*ManualClock)(nil)
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (mc *ManualClock) NowNanoseconds() int64 {
+ return mc.clock.Now().UnixNano()
+}
+
+// NowMonotonic implements tcpip.Clock.NowMonotonic.
+func (mc *ManualClock) NowMonotonic() int64 {
+ return mc.NowNanoseconds()
+}
+
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ until := mc.clock.Now().Add(d)
+ wg := mc.addWait(until)
+ return &manualTimer{
+ clock: mc,
+ until: until,
+ timer: mc.clock.AfterFunc(d, func() {
+ defer wg.Done()
+ f()
+ }),
+ }
+}
+
+// addWait adds an additional wait to the WaitGroup for parallel execution of
+// all work scheduled for t. Returns a reference to the WaitGroup modified.
+func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup {
+ mc.mu.RLock()
+ wg, ok := mc.waitGroups[t]
+ mc.mu.RUnlock()
+
+ if ok {
+ wg.Add(1)
+ return wg
+ }
+
+ mc.mu.Lock()
+ heap.Push(mc.times, t)
+ mc.mu.Unlock()
+
+ wg = &sync.WaitGroup{}
+ wg.Add(1)
+
+ mc.mu.Lock()
+ mc.waitGroups[t] = wg
+ mc.mu.Unlock()
+
+ return wg
+}
+
+// removeWait removes a wait from the WaitGroup for parallel execution of all
+// work scheduled for t.
+func (mc *ManualClock) removeWait(t time.Time) {
+ mc.mu.RLock()
+ defer mc.mu.RUnlock()
+
+ wg := mc.waitGroups[t]
+ wg.Done()
+}
+
+// Advance executes all work that have been scheduled to execute within d from
+// the current time. Blocks until all work has completed execution.
+func (mc *ManualClock) Advance(d time.Duration) {
+ // Block until all the work is done
+ until := mc.clock.Now().Add(d)
+ for {
+ mc.mu.Lock()
+ if mc.times.Len() == 0 {
+ mc.mu.Unlock()
+ break
+ }
+
+ t := heap.Pop(mc.times).(time.Time)
+ if t.After(until) {
+ // No work to do
+ heap.Push(mc.times, t)
+ mc.mu.Unlock()
+ break
+ }
+ mc.mu.Unlock()
+
+ diff := t.Sub(mc.clock.Now())
+ mc.clock.Advance(diff)
+
+ mc.mu.RLock()
+ wg := mc.waitGroups[t]
+ mc.mu.RUnlock()
+
+ wg.Wait()
+
+ mc.mu.Lock()
+ delete(mc.waitGroups, t)
+ mc.mu.Unlock()
+ }
+ if now := mc.clock.Now(); until.After(now) {
+ mc.clock.Advance(until.Sub(now))
+ }
+}
+
+type manualTimer struct {
+ clock *ManualClock
+ timer clockwork.Timer
+
+ mu sync.RWMutex
+ until time.Time
+}
+
+var _ tcpip.Timer = (*manualTimer)(nil)
+
+// Reset implements tcpip.Timer.Reset.
+func (t *manualTimer) Reset(d time.Duration) {
+ if !t.timer.Reset(d) {
+ return
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ t.clock.removeWait(t.until)
+ t.until = t.clock.clock.Now().Add(d)
+ t.clock.addWait(t.until)
+}
+
+// Stop implements tcpip.Timer.Stop.
+func (t *manualTimer) Stop() bool {
+ if !t.timer.Stop() {
+ return false
+ }
+
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+
+ t.clock.removeWait(t.until)
+ return true
+}
+
+type timeHeap []time.Time
+
+var _ heap.Interface = (*timeHeap)(nil)
+
+func (h timeHeap) Len() int {
+ return len(h)
+}
+
+func (h timeHeap) Less(i, j int) bool {
+ return h[i].Before(h[j])
+}
+
+func (h timeHeap) Swap(i, j int) {
+ h[i], h[j] = h[j], h[i]
+}
+
+func (h *timeHeap) Push(x interface{}) {
+ *h = append(*h, x.(time.Time))
+}
+
+func (h *timeHeap) Pop() interface{} {
+ last := (*h)[len(*h)-1]
+ *h = (*h)[:len(*h)-1]
+ return last
+}
diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go
new file mode 100644
index 000000000..c2704df2c
--- /dev/null
+++ b/pkg/tcpip/faketime/faketime_test.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package faketime_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+)
+
+func TestManualClockAdvance(t *testing.T) {
+ const timeout = time.Millisecond
+ clock := faketime.NewManualClock()
+ start := clock.NowMonotonic()
+ clock.Advance(timeout)
+ if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want {
+ t.Errorf("got = %d, want = %d", got, want)
+ }
+}
+
+func TestManualClockAfterFunc(t *testing.T) {
+ const (
+ timeout1 = time.Millisecond // timeout for counter1
+ timeout2 = 2 * time.Millisecond // timeout for counter2
+ )
+ tests := []struct {
+ name string
+ advance time.Duration
+ wantCounter1 int
+ wantCounter2 int
+ }{
+ {
+ name: "before timeout1",
+ advance: timeout1 - 1,
+ wantCounter1: 0,
+ wantCounter2: 0,
+ },
+ {
+ name: "timeout1",
+ advance: timeout1,
+ wantCounter1: 1,
+ wantCounter2: 0,
+ },
+ {
+ name: "timeout2",
+ advance: timeout2,
+ wantCounter1: 1,
+ wantCounter2: 1,
+ },
+ {
+ name: "after timeout2",
+ advance: timeout2 + 1,
+ wantCounter1: 1,
+ wantCounter2: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ counter1 := 0
+ counter2 := 0
+ clock.AfterFunc(timeout1, func() {
+ counter1++
+ })
+ clock.AfterFunc(timeout2, func() {
+ counter2++
+ })
+ start := clock.NowMonotonic()
+ clock.Advance(test.advance)
+ if got, want := counter1, test.wantCounter1; got != want {
+ t.Errorf("got counter1 = %d, want = %d", got, want)
+ }
+ if got, want := counter2, test.wantCounter2; got != want {
+ t.Errorf("got counter2 = %d, want = %d", got, want)
+ }
+ if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want {
+ t.Errorf("got elapsed = %d, want = %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 0cde694dc..d87797617 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -48,7 +48,7 @@ go_test(
"//pkg/rand",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
@@ -64,6 +64,6 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go
index 718a4720a..83189676e 100644
--- a/pkg/tcpip/header/arp.go
+++ b/pkg/tcpip/header/arp.go
@@ -14,14 +14,33 @@
package header
-import "gvisor.dev/gvisor/pkg/tcpip"
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
const (
// ARPProtocolNumber is the ARP network protocol number.
ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806
// ARPSize is the size of an IPv4-over-Ethernet ARP packet.
- ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4
+ ARPSize = 28
+)
+
+// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header.
+type ARPHardwareType uint16
+
+// Typical ARP HardwareType values. Some of the constants have to be specific
+// values as they are egressed on the wire in the HTYPE field of an ARP header.
+const (
+ ARPHardwareNone ARPHardwareType = 0
+ // ARPHardwareEther specifically is the HTYPE for Ethernet as specified
+ // in the IANA list here:
+ //
+ // https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2
+ ARPHardwareEther ARPHardwareType = 1
+ ARPHardwareLoopback ARPHardwareType = 2
)
// ARPOp is an ARP opcode.
@@ -36,54 +55,64 @@ const (
// ARP is an ARP packet stored in a byte array as described in RFC 826.
type ARP []byte
-func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) }
-func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) }
-func (a ARP) hardwareAddressSize() int { return int(a[4]) }
-func (a ARP) protocolAddressSize() int { return int(a[5]) }
+const (
+ hTypeOffset = 0
+ protocolOffset = 2
+ haAddressSizeOffset = 4
+ protoAddressSizeOffset = 5
+ opCodeOffset = 6
+ senderHAAddressOffset = 8
+ senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize
+ targetHAAddressOffset = senderProtocolAddressOffset + IPv4AddressSize
+ targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize
+)
+
+func (a ARP) hardwareAddressType() ARPHardwareType {
+ return ARPHardwareType(binary.BigEndian.Uint16(a[hTypeOffset:]))
+}
+
+func (a ARP) protocolAddressSpace() uint16 { return binary.BigEndian.Uint16(a[protocolOffset:]) }
+func (a ARP) hardwareAddressSize() int { return int(a[haAddressSizeOffset]) }
+func (a ARP) protocolAddressSize() int { return int(a[protoAddressSizeOffset]) }
// Op is the ARP opcode.
-func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) }
+func (a ARP) Op() ARPOp { return ARPOp(binary.BigEndian.Uint16(a[opCodeOffset:])) }
// SetOp sets the ARP opcode.
func (a ARP) SetOp(op ARPOp) {
- a[6] = uint8(op >> 8)
- a[7] = uint8(op)
+ binary.BigEndian.PutUint16(a[opCodeOffset:], uint16(op))
}
// SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet.
func (a ARP) SetIPv4OverEthernet() {
- a[0], a[1] = 0, 1 // htypeEthernet
- a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber
- a[4] = 6 // macSize
- a[5] = uint8(IPv4AddressSize)
+ binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther))
+ binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber))
+ a[haAddressSizeOffset] = EthernetAddressSize
+ a[protoAddressSizeOffset] = uint8(IPv4AddressSize)
}
// HardwareAddressSender is the link address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressSender() []byte {
- const s = 8
- return a[s : s+6]
+ return a[senderHAAddressOffset : senderHAAddressOffset+EthernetAddressSize]
}
// ProtocolAddressSender is the protocol address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressSender() []byte {
- const s = 8 + 6
- return a[s : s+4]
+ return a[senderProtocolAddressOffset : senderProtocolAddressOffset+IPv4AddressSize]
}
// HardwareAddressTarget is the link address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressTarget() []byte {
- const s = 8 + 6 + 4
- return a[s : s+6]
+ return a[targetHAAddressOffset : targetHAAddressOffset+EthernetAddressSize]
}
// ProtocolAddressTarget is the protocol address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressTarget() []byte {
- const s = 8 + 6 + 4 + 6
- return a[s : s+4]
+ return a[targetProtocolAddressOffset : targetProtocolAddressOffset+IPv4AddressSize]
}
// IsValid reports whether this is an ARP packet for IPv4 over Ethernet.
@@ -91,10 +120,8 @@ func (a ARP) IsValid() bool {
if len(a) < ARPSize {
return false
}
- const htypeEthernet = 1
- const macSize = 6
- return a.hardwareAddressSpace() == htypeEthernet &&
+ return a.hardwareAddressType() == ARPHardwareEther &&
a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) &&
- a.hardwareAddressSize() == macSize &&
+ a.hardwareAddressSize() == EthernetAddressSize &&
a.protocolAddressSize() == IPv4AddressSize
}
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index b1e92d2d7..eaface8cb 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -53,6 +53,10 @@ const (
// (all bits set to 0).
unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+ // EthernetBroadcastAddress is an ethernet address that addresses every node
+ // on a local link.
+ EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff")
+
// unicastMulticastFlagMask is the mask of the least significant bit in
// the first octet (in network byte order) of an ethernet address that
// determines whether the ethernet address is a unicast or multicast. If
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 7908c5744..504408878 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -31,6 +31,27 @@ const (
// ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
ICMPv4MinimumSize = 8
+ // ICMPv4MinimumErrorPayloadSize Is the smallest number of bytes of an
+ // errant packet's transport layer that an ICMP error type packet should
+ // attempt to send as per RFC 792 (see each type) and RFC 1122
+ // section 3.2.2 which states:
+ // Every ICMP error message includes the Internet header and at
+ // least the first 8 data octets of the datagram that triggered
+ // the error; more than 8 octets MAY be sent; this header and data
+ // MUST be unchanged from the received datagram.
+ //
+ // RFC 792 shows:
+ // 0 1 2 3
+ // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Type | Code | Checksum |
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | unused |
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Internet Header + 64 bits of Original Data Datagram |
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ ICMPv4MinimumErrorPayloadSize = 8
+
// ICMPv4ProtocolNumber is the ICMP transport protocol number.
ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
@@ -39,21 +60,28 @@ const (
icmpv4ChecksumOffset = 2
// icmpv4MTUOffset is the offset of the MTU field
- // in a ICMPv4FragmentationNeeded message.
+ // in an ICMPv4FragmentationNeeded message.
icmpv4MTUOffset = 6
// icmpv4IdentOffset is the offset of the ident field
- // in a ICMPv4EchoRequest/Reply message.
+ // in an ICMPv4EchoRequest/Reply message.
icmpv4IdentOffset = 4
+ // icmpv4PointerOffset is the offset of the pointer field
+ // in an ICMPv4ParamProblem message.
+ icmpv4PointerOffset = 4
+
// icmpv4SequenceOffset is the offset of the sequence field
- // in a ICMPv4EchoRequest/Reply message.
+ // in an ICMPv4EchoRequest/Reply message.
icmpv4SequenceOffset = 6
)
// ICMPv4Type is the ICMP type field described in RFC 792.
type ICMPv4Type byte
+// ICMPv4Code is the ICMP code field described in RFC 792.
+type ICMPv4Code byte
+
// Typical values of ICMPv4Type defined in RFC 792.
const (
ICMPv4EchoReply ICMPv4Type = 0
@@ -69,13 +97,23 @@ const (
ICMPv4InfoReply ICMPv4Type = 16
)
-// Values for ICMP code as defined in RFC 792.
+// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792.
const (
- ICMPv4TTLExceeded = 0
- ICMPv4PortUnreachable = 3
- ICMPv4FragmentationNeeded = 4
+ ICMPv4TTLExceeded ICMPv4Code = 0
)
+// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792.
+const (
+ ICMPv4NetUnreachable ICMPv4Code = 0
+ ICMPv4HostUnreachable ICMPv4Code = 1
+ ICMPv4ProtoUnreachable ICMPv4Code = 2
+ ICMPv4PortUnreachable ICMPv4Code = 3
+ ICMPv4FragmentationNeeded ICMPv4Code = 4
+)
+
+// ICMPv4UnusedCode is a code to use in ICMP messages where no code is needed.
+const ICMPv4UnusedCode ICMPv4Code = 0
+
// Type is the ICMP type field.
func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
@@ -83,10 +121,10 @@ func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) }
// Code is the ICMP code field. Its meaning depends on the value of Type.
-func (b ICMPv4) Code() byte { return b[1] }
+func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) }
// SetCode sets the ICMP code field.
-func (b ICMPv4) SetCode(c byte) { b[1] = c }
+func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) }
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index c7ee2de57..6be31beeb 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -54,9 +54,17 @@ const (
// address.
ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize
- // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
+ // ICMPv6EchoMinimumSize is the minimum size of a valid echo packet.
ICMPv6EchoMinimumSize = 8
+ // ICMPv6ErrorHeaderSize is the size of an ICMP error packet header,
+ // as per RFC 4443, Apendix A, item 4 and the errata.
+ // ... all ICMP error messages shall have exactly
+ // 32 bits of type-specific data, so that receivers can reliably find
+ // the embedded invoking packet even when they don't recognize the
+ // ICMP message Type.
+ ICMPv6ErrorHeaderSize = 8
+
// ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
// destination unreachable packet.
ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize
@@ -69,6 +77,10 @@ const (
// in an ICMPv6 message.
icmpv6ChecksumOffset = 2
+ // icmpv6PointerOffset is the offset of the pointer
+ // in an ICMPv6 Parameter problem message.
+ icmpv6PointerOffset = 4
+
// icmpv6MTUOffset is the offset of the MTU field in an ICMPv6
// PacketTooBig message.
icmpv6MTUOffset = 4
@@ -89,10 +101,10 @@ const (
NDPHopLimit = 255
)
-// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
+// ICMPv6Type is the ICMP type field described in RFC 4443.
type ICMPv6Type byte
-// Typical values of ICMPv6Type defined in RFC 4443.
+// Values for use in the Type field of ICMPv6 packet from RFC 4433.
const (
ICMPv6DstUnreachable ICMPv6Type = 1
ICMPv6PacketTooBig ICMPv6Type = 2
@@ -110,11 +122,54 @@ const (
ICMPv6RedirectMsg ICMPv6Type = 137
)
-// Values for ICMP code as defined in RFC 4443.
+// IsErrorType returns true if the receiver is an ICMP error type.
+func (typ ICMPv6Type) IsErrorType() bool {
+ // Per RFC 4443 section 2.1:
+ // ICMPv6 messages are grouped into two classes: error messages and
+ // informational messages. Error messages are identified as such by a
+ // zero in the high-order bit of their message Type field values. Thus,
+ // error messages have message types from 0 to 127; informational
+ // messages have message types from 128 to 255.
+ return typ&0x80 == 0
+}
+
+// ICMPv6Code is the ICMP Code field described in RFC 4443.
+type ICMPv6Code byte
+
+// ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443
+// section 3.1.
+const (
+ ICMPv6NetworkUnreachable ICMPv6Code = 0
+ ICMPv6Prohibited ICMPv6Code = 1
+ ICMPv6BeyondScope ICMPv6Code = 2
+ ICMPv6AddressUnreachable ICMPv6Code = 3
+ ICMPv6PortUnreachable ICMPv6Code = 4
+ ICMPv6Policy ICMPv6Code = 5
+ ICMPv6RejectRoute ICMPv6Code = 6
+)
+
+// ICMP codes used with Time Exceeded (Type 3). As per RFC 4443 section 3.3.
+const (
+ ICMPv6HopLimitExceeded ICMPv6Code = 0
+ ICMPv6ReassemblyTimeout ICMPv6Code = 1
+)
+
+// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4.
const (
- ICMPv6PortUnreachable = 4
+ // ICMPv6ErroneousHeader indicates an erroneous header field was encountered.
+ ICMPv6ErroneousHeader ICMPv6Code = 0
+
+ // ICMPv6UnknownHeader indicates an unrecognized Next Header type encountered.
+ ICMPv6UnknownHeader ICMPv6Code = 1
+
+ // ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered.
+ ICMPv6UnknownOption ICMPv6Code = 2
)
+// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use
+// the code field. (Types not mentioned above.)
+const ICMPv6UnusedCode ICMPv6Code = 0
+
// Type is the ICMP type field.
func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
@@ -122,10 +177,20 @@ func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) }
// Code is the ICMP code field. Its meaning depends on the value of Type.
-func (b ICMPv6) Code() byte { return b[1] }
+func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) }
// SetCode sets the ICMP code field.
-func (b ICMPv6) SetCode(c byte) { b[1] = c }
+func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) }
+
+// TypeSpecific returns the type specific data field.
+func (b ICMPv6) TypeSpecific() uint32 {
+ return binary.BigEndian.Uint32(b[icmpv6PointerOffset:])
+}
+
+// SetTypeSpecific sets the type specific data field.
+func (b ICMPv6) SetTypeSpecific(val uint32) {
+ binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val)
+}
// Checksum is the ICMP checksum field.
func (b ICMPv6) Checksum() uint16 {
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 62ac932bb..4c6e4be64 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -16,10 +16,29 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
)
+// RFC 971 defines the fields of the IPv4 header on page 11 using the following
+// diagram: ("Figure 4")
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |Version| IHL |Type of Service| Total Length |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Identification |Flags| Fragment Offset |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Time to Live | Protocol | Header Checksum |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Source Address |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Destination Address |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Options | Padding |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+//
const (
versIHL = 0
tos = 1
@@ -33,6 +52,7 @@ const (
checksum = 10
srcAddr = 12
dstAddr = 16
+ options = 20
)
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
@@ -76,11 +96,13 @@ type IPv4Fields struct {
// IPv4 represents an ipv4 header stored in a byte array.
// Most of the methods of IPv4 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
-// Always call IsValid() to validate an instance of IPv4 before using other methods.
+// Always call IsValid() to validate an instance of IPv4 before using other
+// methods.
type IPv4 []byte
const (
- // IPv4MinimumSize is the minimum size of a valid IPv4 packet.
+ // IPv4MinimumSize is the minimum size of a valid IPv4 packet;
+ // i.e. a packet header with no options.
IPv4MinimumSize = 20
// IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
@@ -88,6 +110,16 @@ const (
// units, the header cannot exceed 15*4 = 60 bytes.
IPv4MaximumHeaderSize = 60
+ // IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload.
+ //
+ // Linux limits this to 65,515 octets (the max IP datagram size - the IPv4
+ // header size). But RFC 791 section 3.2 discusses the design of the IPv4
+ // fragment "allows 2**13 = 8192 fragments of 8 octets each for a total of
+ // 65,536 octets. Note that this is consistent with the the datagram total
+ // length field (of course, the header is counted in the total length and not
+ // in the fragments)."
+ IPv4MaximumPayloadSize = 65536
+
// MinIPFragmentPayloadSize is the minimum number of payload bytes that
// the first fragment must carry when an IPv4 packet is fragmented.
MinIPFragmentPayloadSize = 8
@@ -101,6 +133,11 @@ const (
// IPv4Version is the version of the ipv4 protocol.
IPv4Version = 4
+ // IPv4AllSystems is the all systems IPv4 multicast address as per
+ // IANA's IPv4 Multicast Address Space Registry. See
+ // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml.
+ IPv4AllSystems tcpip.Address = "\xe0\x00\x00\x01"
+
// IPv4Broadcast is the broadcast address of the IPv4 procotol.
IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff"
@@ -135,13 +172,44 @@ func IPVersion(b []byte) int {
if len(b) < versIHL+1 {
return -1
}
- return int(b[versIHL] >> 4)
+ return int(b[versIHL] >> ipVersionShift)
}
+// RFC 791 page 11 shows the header length (IHL) is in the lower 4 bits
+// of the first byte, and is counted in multiples of 4 bytes.
+//
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |Version| IHL |Type of Service| Total Length |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// (...)
+// Version: 4 bits
+// The Version field indicates the format of the internet header. This
+// document describes version 4.
+//
+// IHL: 4 bits
+// Internet Header Length is the length of the internet header in 32
+// bit words, and thus points to the beginning of the data. Note that
+// the minimum value for a correct header is 5.
+//
+const (
+ ipVersionShift = 4
+ ipIHLMask = 0x0f
+ IPv4IHLStride = 4
+)
+
// HeaderLength returns the value of the "header length" field of the ipv4
// header. The length returned is in bytes.
func (b IPv4) HeaderLength() uint8 {
- return (b[versIHL] & 0xf) * 4
+ return (b[versIHL] & ipIHLMask) * IPv4IHLStride
+}
+
+// SetHeaderLength sets the value of the "Internet Header Length" field.
+func (b IPv4) SetHeaderLength(hdrLen uint8) {
+ if hdrLen > IPv4MaximumHeaderSize {
+ panic(fmt.Sprintf("got IPv4 Header size = %d, want <= %d", hdrLen, IPv4MaximumHeaderSize))
+ }
+ b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask)
}
// ID returns the value of the identifier field of the ipv4 header.
@@ -195,6 +263,12 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
+// Options returns a a buffer holding the options.
+func (b IPv4) Options() []byte {
+ hdrLen := b.HeaderLength()
+ return b[options:hdrLen:hdrLen]
+}
+
// TransportProtocol implements Network.TransportProtocol.
func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber {
return tcpip.TransportProtocolNumber(b.Protocol())
@@ -220,6 +294,11 @@ func (b IPv4) SetTOS(v uint8, _ uint32) {
b[tos] = v
}
+// SetTTL sets the "Time to Live" field of the IPv4 header.
+func (b IPv4) SetTTL(v byte) {
+ b[ttl] = v
+}
+
// SetTotalLength sets the "total length" field of the ipv4 header.
func (b IPv4) SetTotalLength(totalLength uint16) {
binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
@@ -260,7 +339,7 @@ func (b IPv4) CalculateChecksum() uint16 {
// Encode encodes all the fields of the ipv4 header.
func (b IPv4) Encode(i *IPv4Fields) {
- b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
+ b.SetHeaderLength(i.IHL)
b[tos] = i.TOS
b.SetTotalLength(i.TotalLength)
binary.BigEndian.PutUint16(b[id:], i.ID)
@@ -310,3 +389,12 @@ func IsV4MulticastAddress(addr tcpip.Address) bool {
}
return (addr[0] & 0xf0) == 0xe0
}
+
+// IsV4LoopbackAddress determines if the provided address is an IPv4 loopback
+// address (belongs to 127.0.0.0/8 subnet). See RFC 1122 section 3.2.1.3.
+func IsV4LoopbackAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv4AddressSize {
+ return false
+ }
+ return addr[0] == 0x7f
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 4f367fe4c..ef454b313 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -34,6 +34,9 @@ const (
hopLimit = 7
v6SrcAddr = 8
v6DstAddr = v6SrcAddr + IPv6AddressSize
+
+ // IPv6FixedHeaderSize is the size of the fixed header.
+ IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
@@ -69,11 +72,15 @@ type IPv6 []byte
const (
// IPv6MinimumSize is the minimum size of a valid IPv6 packet.
- IPv6MinimumSize = 40
+ IPv6MinimumSize = IPv6FixedHeaderSize
// IPv6AddressSize is the size, in bytes, of an IPv6 address.
IPv6AddressSize = 16
+ // IPv6MaximumPayloadSize is the maximum size of a valid IPv6 payload per
+ // RFC 8200 Section 4.5.
+ IPv6MaximumPayloadSize = 65535
+
// IPv6ProtocolNumber is IPv6's network protocol number.
IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd
@@ -98,6 +105,9 @@ const (
// section 5.
IPv6MinimumMTU = 1280
+ // IPv6Loopback is the IPv6 Loopback address.
+ IPv6Loopback tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+
// IPv6Any is the non-routable IPv6 "any" meta address. It is also
// known as the unspecified address.
IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 3499d8399..583c2c5d3 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -149,6 +149,19 @@ func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
// obtained before modification is no longer used.
type IPv6OptionsExtHdrOptionsIterator struct {
reader bytes.Reader
+
+ // optionOffset is the number of bytes from the first byte of the
+ // options field to the beginning of the current option.
+ optionOffset uint32
+
+ // nextOptionOffset is the offset of the next option.
+ nextOptionOffset uint32
+}
+
+// OptionOffset returns the number of bytes parsed while processing the
+// option field of the current Extension Header.
+func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 {
+ return i.optionOffset
}
// IPv6OptionUnknownAction is the action that must be taken if the processing
@@ -226,6 +239,7 @@ func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
// the options data, or an error occured.
func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
for {
+ i.optionOffset = i.nextOptionOffset
temp, err := i.reader.ReadByte()
if err != nil {
// If we can't read the first byte of a new option, then we know the
@@ -238,6 +252,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
// know the option does not have Length and Data fields. End processing of
// the Pad1 option and continue processing the buffer as a new option.
if id == ipv6Pad1ExtHdrOptionIdentifier {
+ i.nextOptionOffset = i.optionOffset + 1
continue
}
@@ -254,41 +269,40 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
}
- // Special-case the variable length padding option to avoid a copy.
- if id == ipv6PadNExtHdrOptionIdentifier {
- // Do we have enough bytes in the reader for the PadN option?
- if n := i.reader.Len(); n < int(length) {
- // Reset the reader to effectively consume the remaining buffer.
- i.reader.Reset(nil)
-
- // We return the same error as if we failed to read a non-padding option
- // so consumers of this iterator don't need to differentiate between
- // padding and non-padding options.
- return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
- }
+ // Do we have enough bytes in the reader for the next option?
+ if n := i.reader.Len(); n < int(length) {
+ // Reset the reader to effectively consume the remaining buffer.
+ i.reader.Reset(nil)
+
+ // We return the same error as if we failed to read a non-padding option
+ // so consumers of this iterator don't need to differentiate between
+ // padding and non-padding options.
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
+ }
+
+ i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */
+ switch id {
+ case ipv6PadNExtHdrOptionIdentifier:
+ // Special-case the variable length padding option to avoid a copy.
if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil {
panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
}
-
- // End processing of the PadN option and continue processing the buffer as
- // a new option.
continue
- }
-
- bytes := make([]byte, length)
- if n, err := io.ReadFull(&i.reader, bytes); err != nil {
- // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
- // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
- // Length field found in the option.
- if err == io.EOF {
- err = io.ErrUnexpectedEOF
+ default:
+ bytes := make([]byte, length)
+ if n, err := io.ReadFull(&i.reader, bytes); err != nil {
+ // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
+ // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
+ // Length field found in the option.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
}
-
- return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
+ return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
}
-
- return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
}
}
@@ -382,6 +396,29 @@ type IPv6PayloadIterator struct {
// Indicates to the iterator that it should return the remaining payload as a
// raw payload on the next call to Next.
forceRaw bool
+
+ // headerOffset is the offset of the beginning of the current extension
+ // header starting from the beginning of the fixed header.
+ headerOffset uint32
+
+ // parseOffset is the byte offset into the current extension header of the
+ // field we are currently examining. It can be added to the header offset
+ // if the absolute offset within the packet is required.
+ parseOffset uint32
+
+ // nextOffset is the offset of the next header.
+ nextOffset uint32
+}
+
+// HeaderOffset returns the offset to the start of the extension
+// header most recently processed.
+func (i IPv6PayloadIterator) HeaderOffset() uint32 {
+ return i.headerOffset
+}
+
+// ParseOffset returns the number of bytes successfully parsed.
+func (i IPv6PayloadIterator) ParseOffset() uint32 {
+ return i.headerOffset + i.parseOffset
}
// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing
@@ -397,7 +434,8 @@ func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, pa
nextHdrIdentifier: nextHdrIdentifier,
payload: payload.Clone(nil),
// We need a buffer of size 1 for calls to bufio.Reader.ReadByte.
- reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ nextOffset: IPv6FixedHeaderSize,
}
}
@@ -434,6 +472,8 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
// Next is unable to return anything because the iterator has reached the end of
// the payload, or an error occured.
func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
+ i.headerOffset = i.nextOffset
+ i.parseOffset = 0
// We could be forced to return i as a raw header when the previous header was
// a fragment extension header as the data following the fragment extension
// header may not be complete.
@@ -461,7 +501,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
return IPv6RoutingExtHdr(bytes), false, nil
case IPv6FragmentExtHdrIdentifier:
var data [6]byte
- // We ignore the returned bytes becauase we know the fragment extension
+ // We ignore the returned bytes because we know the fragment extension
// header specific data will fit in data.
nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
if err != nil {
@@ -519,10 +559,12 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
if err != nil {
return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
}
+ i.parseOffset++
var length uint8
length, err = i.reader.ReadByte()
i.payload.TrimFront(1)
+
if err != nil {
if fragmentHdr {
return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
@@ -534,6 +576,17 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
length = 0
}
+ // Make parseOffset point to the first byte of the Extension Header
+ // specific data.
+ i.parseOffset++
+
+ // length is in 8 byte chunks but doesn't include the first one.
+ // See RFC 8200 for each header type, sections 4.3-4.6 and the requirement
+ // in section 4.8 for new extension headers at the top of page 24.
+ // [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet
+ // units, not including the first 8 octets.
+ i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit)
+
bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
if bytes == nil {
bytes = make([]byte, bytesLen)
diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go
index b5540bf66..17a49d4fa 100644
--- a/pkg/tcpip/header/ipversion_test.go
+++ b/pkg/tcpip/header/ipversion_test.go
@@ -22,7 +22,7 @@ import (
func TestIPv4(t *testing.T) {
b := header.IPv4(make([]byte, header.IPv4MinimumSize))
- b.Encode(&header.IPv4Fields{})
+ b.Encode(&header.IPv4Fields{IHL: header.IPv4MinimumSize})
const want = header.IPv4Version
if v := header.IPVersion(b); v != want {
diff --git a/pkg/tcpip/header/parse/BUILD b/pkg/tcpip/header/parse/BUILD
new file mode 100644
index 000000000..2adee9288
--- /dev/null
+++ b/pkg/tcpip/header/parse/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "parse",
+ srcs = ["parse.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go
new file mode 100644
index 000000000..5ca75c834
--- /dev/null
+++ b/pkg/tcpip/header/parse/parse.go
@@ -0,0 +1,168 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package parse provides utilities to parse packets.
+package parse
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// ARP populates pkt's network header with an ARP header found in
+// pkt.Data.
+//
+// Returns true if the header was successfully parsed.
+func ARP(pkt *stack.PacketBuffer) bool {
+ _, ok := pkt.NetworkHeader().Consume(header.ARPSize)
+ if ok {
+ pkt.NetworkProtocolNumber = header.ARPProtocolNumber
+ }
+ return ok
+}
+
+// IPv4 parses an IPv4 packet found in pkt.Data and populates pkt's network
+// header with the IPv4 header.
+//
+// Returns true if the header was successfully parsed.
+func IPv4(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return false
+ }
+ ipHdr := header.IPv4(hdr)
+
+ // Header may have options, determine the true header length.
+ headerLen := int(ipHdr.HeaderLength())
+ if headerLen < header.IPv4MinimumSize {
+ // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in
+ // order for the packet to be valid. Figure out if we want to reject this
+ // case.
+ headerLen = header.IPv4MinimumSize
+ }
+ hdr, ok = pkt.NetworkHeader().Consume(headerLen)
+ if !ok {
+ return false
+ }
+ ipHdr = header.IPv4(hdr)
+
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
+ return true
+}
+
+// IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network
+// header with the IPv6 header.
+func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return 0, 0, 0, false, false
+ }
+ ipHdr := header.IPv6(hdr)
+
+ // dataClone consists of:
+ // - Any IPv6 header bytes after the first 40 (i.e. extensions).
+ // - The transport header, if present.
+ // - Any other payload data.
+ views := [8]buffer.View{}
+ dataClone := pkt.Data.Clone(views[:])
+ dataClone.TrimFront(header.IPv6MinimumSize)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
+
+ // Iterate over the IPv6 extensions to find their length.
+ var nextHdr tcpip.TransportProtocolNumber
+ var extensionsSize int
+
+traverseExtensions:
+ for {
+ extHdr, done, err := it.Next()
+ if err != nil {
+ break
+ }
+
+ // If we exhaust the extension list, the entire packet is the IPv6 header
+ // and (possibly) extensions.
+ if done {
+ extensionsSize = dataClone.Size()
+ break
+ }
+
+ switch extHdr := extHdr.(type) {
+ case header.IPv6FragmentExtHdr:
+ if fragID == 0 && fragOffset == 0 && !fragMore {
+ fragID = extHdr.ID()
+ fragOffset = extHdr.FragmentOffset()
+ fragMore = extHdr.More()
+ }
+
+ case header.IPv6RawPayloadHeader:
+ // We've found the payload after any extensions.
+ extensionsSize = dataClone.Size() - extHdr.Buf.Size()
+ nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
+ break traverseExtensions
+
+ default:
+ // Any other extension is a no-op, keep looping until we find the payload.
+ }
+ }
+
+ // Put the IPv6 header with extensions in pkt.NetworkHeader().
+ hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize)
+ if !ok {
+ panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
+ }
+ ipHdr = header.IPv6(hdr)
+ pkt.Data.CapLength(int(ipHdr.PayloadLength()))
+ pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
+
+ return nextHdr, fragID, fragOffset, fragMore, true
+}
+
+// UDP parses a UDP packet found in pkt.Data and populates pkt's transport
+// header with the UDP header.
+//
+// Returns true if the header was successfully parsed.
+func UDP(pkt *stack.PacketBuffer) bool {
+ _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize)
+ pkt.TransportProtocolNumber = header.UDPProtocolNumber
+ return ok
+}
+
+// TCP parses a TCP packet found in pkt.Data and populates pkt's transport
+// header with the TCP header.
+//
+// Returns true if the header was successfully parsed.
+func TCP(pkt *stack.PacketBuffer) bool {
+ // TCP header is variable length, peek at it first.
+ hdrLen := header.TCPMinimumSize
+ hdr, ok := pkt.Data.PullUp(hdrLen)
+ if !ok {
+ return false
+ }
+
+ // If the header has options, pull those up as well.
+ if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
+ // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of
+ // packets.
+ hdrLen = offset
+ }
+
+ _, ok = pkt.TransportHeader().Consume(hdrLen)
+ pkt.TransportProtocolNumber = header.TCPProtocolNumber
+ return ok
+}
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index 9339d637f..98bdd29db 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -16,6 +16,7 @@ package header
import (
"encoding/binary"
+ "math"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -55,6 +56,10 @@ const (
// UDPMinimumSize is the minimum size of a valid UDP packet.
UDPMinimumSize = 8
+ // UDPMaximumSize is the maximum size of a valid UDP packet. The length field
+ // in the UDP header is 16 bits as per RFC 768.
+ UDPMaximumSize = math.MaxUint16
+
// UDPProtocolNumber is UDP's transport protocol number.
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index b8b93e78e..39ca774ef 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 20b183da0..c95aef63c 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -273,7 +274,9 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
p := PacketInfo{
- Pkt: &stack.PacketBuffer{Data: vv},
+ Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }),
Proto: 0,
GSO: nil,
}
@@ -296,3 +299,12 @@ func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle {
func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
e.q.RemoveNotify(handle)
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareNone
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index aa6db9aea..10072eac1 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -15,6 +15,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/binary",
+ "//pkg/iovec",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
@@ -36,5 +37,6 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/rawfile",
"//pkg/tcpip/stack",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index f34082e1a..975309fc8 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -45,6 +45,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -385,32 +386,40 @@ const (
_VIRTIO_NET_HDR_GSO_TCPV6 = 4
)
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if e.hdrSize > 0 {
// Add ethernet header if needed.
- eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
- pkt.LinkHeader = buffer.View(eth)
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
// Preserve the src address if it's set in the route.
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
}
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if e.hdrSize > 0 {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
+
+ var builder iovec.Builder
fd := e.fds[pkt.Hash%uint32(len(e.fds))]
if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
vnetHdr := virtioNetHdr{}
if gso != nil {
- vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
+ vnetHdr.hdrLen = uint16(pkt.HeaderSize())
if gso.NeedsCsum {
vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen
@@ -430,49 +439,28 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
}
vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
- return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView())
+ builder.Add(vnetHdrBuf)
}
- if pkt.Data.Size() == 0 {
- return rawfile.NonBlockingWrite(fd, pkt.Header.View())
- }
- if pkt.Header.UsedLength() == 0 {
- return rawfile.NonBlockingWrite(fd, pkt.Data.ToView())
+ for _, v := range pkt.Views() {
+ builder.Add(v)
}
-
- return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil)
+ return rawfile.NonBlockingWriteIovec(fd, builder.Build())
}
func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) {
// Send a batch of packets through batchFD.
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
- var ethHdrBuf []byte
- iovLen := 0
if e.hdrSize > 0 {
- // Add ethernet header if needed.
- ethHdrBuf = make([]byte, header.EthernetMinimumSize)
- eth := header.Ethernet(ethHdrBuf)
- ethHdr := &header.EthernetFields{
- DstAddr: pkt.EgressRoute.RemoteLinkAddress,
- Type: pkt.NetworkProtocolNumber,
- }
-
- // Preserve the src address if it's set in the route.
- if pkt.EgressRoute.LocalLinkAddress != "" {
- ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress
- } else {
- ethHdr.SrcAddr = e.addr
- }
- eth.Encode(ethHdr)
- iovLen++
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
- vnetHdr := virtioNetHdr{}
var vnetHdrBuf []byte
if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ vnetHdr := virtioNetHdr{}
if pkt.GSOOptions != nil {
- vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
+ vnetHdr.hdrLen = uint16(pkt.HeaderSize())
if pkt.GSOOptions.NeedsCsum {
vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
@@ -491,45 +479,18 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
}
}
vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
- iovLen++
}
- iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views()))
+ var builder iovec.Builder
+ builder.Add(vnetHdrBuf)
+ for _, v := range pkt.Views() {
+ builder.Add(v)
+ }
+ iovecs := builder.Build()
+
var mmsgHdr rawfile.MMsgHdr
mmsgHdr.Msg.Iov = &iovecs[0]
- iovecIdx := 0
- if vnetHdrBuf != nil {
- v := &iovecs[iovecIdx]
- v.Base = &vnetHdrBuf[0]
- v.Len = uint64(len(vnetHdrBuf))
- iovecIdx++
- }
- if ethHdrBuf != nil {
- v := &iovecs[iovecIdx]
- v.Base = &ethHdrBuf[0]
- v.Len = uint64(len(ethHdrBuf))
- iovecIdx++
- }
- pktSize := uint64(0)
- // Encode L3 Header
- v := &iovecs[iovecIdx]
- hdr := &pkt.Header
- hdrView := hdr.View()
- v.Base = &hdrView[0]
- v.Len = uint64(len(hdrView))
- pktSize += v.Len
- iovecIdx++
-
- // Now encode the Transport Payload.
- pktViews := pkt.Data.Views()
- for i := range pktViews {
- vec := &iovecs[iovecIdx]
- iovecIdx++
- vec.Base = &pktViews[i][0]
- vec.Len = uint64(len(pktViews[i]))
- pktSize += vec.Len
- }
- mmsgHdr.Msg.Iovlen = uint64(iovecIdx)
+ mmsgHdr.Msg.Iovlen = uint64(len(iovecs))
mmsgHdrs = append(mmsgHdrs, mmsgHdr)
}
@@ -626,6 +587,14 @@ func (e *endpoint) GSOMaxSize() uint32 {
return e.gsoMaxSize
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.hdrSize > 0 {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
+
// InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes
// to the FD, but does not read from it. All reads come from injected packets.
type InjectableEndpoint struct {
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index eaee7e5d7..709f829c8 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -26,6 +26,7 @@ import (
"time"
"unsafe"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -43,9 +44,36 @@ const (
)
type packetInfo struct {
- raddr tcpip.LinkAddress
- proto tcpip.NetworkProtocolNumber
- contents *stack.PacketBuffer
+ Raddr tcpip.LinkAddress
+ Proto tcpip.NetworkProtocolNumber
+ Contents *stack.PacketBuffer
+}
+
+type packetContents struct {
+ LinkHeader buffer.View
+ NetworkHeader buffer.View
+ TransportHeader buffer.View
+ Data buffer.View
+}
+
+func checkPacketInfoEqual(t *testing.T, got, want packetInfo) {
+ t.Helper()
+ if diff := cmp.Diff(
+ want, got,
+ cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents {
+ if pk == nil {
+ return nil
+ }
+ return &packetContents{
+ LinkHeader: pk.LinkHeader().View(),
+ NetworkHeader: pk.NetworkHeader().View(),
+ TransportHeader: pk.TransportHeader().View(),
+ Data: pk.Data.ToView(),
+ }
+ }),
+ ); diff != "" {
+ t.Errorf("unexpected packetInfo (-want +got):\n%s", diff)
+ }
}
type context struct {
@@ -107,6 +135,10 @@ func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.Lin
c.ch <- packetInfo{remote, protocol, pkt}
}
+func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNoEthernetProperties(t *testing.T) {
c := newContext(t, &Options{MTU: mtu})
defer c.cleanup()
@@ -155,19 +187,28 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
RemoteLinkAddress: raddr,
}
- // Build header.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100)
- b := hdr.Prepend(100)
- for i := range b {
- b[i] = uint8(rand.Intn(256))
+ // Build payload.
+ payload := buffer.NewView(plen)
+ if _, err := rand.Read(payload); err != nil {
+ t.Fatalf("rand.Read(payload): %s", err)
}
- // Build payload and write.
- payload := make(buffer.View, plen)
- for i := range payload {
- payload[i] = uint8(rand.Intn(256))
+ // Build packet buffer.
+ const netHdrLen = 100
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen,
+ Data: payload.ToVectorisedView(),
+ })
+ pkt.Hash = hash
+
+ // Build header.
+ b := pkt.NetworkHeader().Push(netHdrLen)
+ if _, err := rand.Read(b); err != nil {
+ t.Fatalf("rand.Read(b): %s", err)
}
- want := append(hdr.View(), payload...)
+
+ // Write.
+ want := append(append(buffer.View(nil), b...), payload...)
var gso *stack.GSO
if gsoMaxSize != 0 {
gso = &stack.GSO{
@@ -179,11 +220,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
L3HdrLen: header.IPv4MaximumHeaderSize,
}
}
- if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{
- Header: hdr,
- Data: payload.ToVectorisedView(),
- Hash: hash,
- }); err != nil {
+ if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -292,13 +329,14 @@ func TestPreserveSrcAddress(t *testing.T) {
LocalLinkAddress: baddr,
}
- // WritePacket panics given a prependable with anything less than
- // the minimum size of the ethernet header.
- hdr := buffer.NewPrependable(header.EthernetMinimumSize)
- if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{
- Header: hdr,
- Data: buffer.VectorisedView{},
- }); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ // WritePacket panics given a prependable with anything less than
+ // the minimum size of the ethernet header.
+ // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength().
+ ReserveHeaderBytes: header.EthernetMinimumSize,
+ Data: buffer.VectorisedView{},
+ })
+ if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -327,24 +365,25 @@ func TestDeliverPacket(t *testing.T) {
defer c.cleanup()
// Build packet.
- b := make([]byte, plen)
- all := b
- for i := range b {
- b[i] = uint8(rand.Intn(256))
+ all := make([]byte, plen)
+ if _, err := rand.Read(all); err != nil {
+ t.Fatalf("rand.Read(all): %s", err)
}
-
- var hdr header.Ethernet
- if !eth {
- // So that it looks like an IPv4 packet.
- b[0] = 0x40
- } else {
- hdr = make(header.Ethernet, header.EthernetMinimumSize)
+ // Make it look like an IPv4 packet.
+ all[0] = 0x40
+
+ wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.EthernetMinimumSize,
+ Data: buffer.NewViewFromBytes(all).ToVectorisedView(),
+ })
+ if eth {
+ hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize))
hdr.Encode(&header.EthernetFields{
SrcAddr: raddr,
DstAddr: laddr,
Type: proto,
})
- all = append(hdr, b...)
+ all = append(hdr, all...)
}
// Write packet via the file descriptor.
@@ -356,24 +395,15 @@ func TestDeliverPacket(t *testing.T) {
select {
case pi := <-c.ch:
want := packetInfo{
- raddr: raddr,
- proto: proto,
- contents: &stack.PacketBuffer{
- Data: buffer.View(b).ToVectorisedView(),
- LinkHeader: buffer.View(hdr),
- },
+ Raddr: raddr,
+ Proto: proto,
+ Contents: wantPkt,
}
if !eth {
- want.proto = header.IPv4ProtocolNumber
- want.raddr = ""
- }
- // want.contents.Data will be a single
- // view, so make pi do the same for the
- // DeepEqual check.
- pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView()
- if !reflect.DeepEqual(want, pi) {
- t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want)
+ want.Proto = header.IPv4ProtocolNumber
+ want.Raddr = ""
}
+ checkPacketInfoEqual(t, pi, want)
case <-time.After(10 * time.Second):
t.Fatalf("Timed out waiting for packet")
}
@@ -500,3 +530,80 @@ func TestRecvMMsgDispatcherCapLength(t *testing.T) {
}
}
+
+// fakeNetworkDispatcher delivers packets to pkts.
+type fakeNetworkDispatcher struct {
+ pkts []*stack.PacketBuffer
+}
+
+func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ d.pkts = append(d.pkts, pkt)
+}
+
+func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
+func TestDispatchPacketFormat(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ newDispatcher func(fd int, e *endpoint) (linkDispatcher, error)
+ }{
+ {
+ name: "readVDispatcher",
+ newDispatcher: newReadVDispatcher,
+ },
+ {
+ name: "recvMMsgDispatcher",
+ newDispatcher: newRecvMMsgDispatcher,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ // Create a socket pair to send/recv.
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer syscall.Close(fds[0])
+ defer syscall.Close(fds[1])
+
+ data := []byte{
+ // Ethernet header.
+ 1, 2, 3, 4, 5, 60,
+ 1, 2, 3, 4, 5, 61,
+ 8, 0,
+ // Mock network header.
+ 40, 41, 42, 43,
+ }
+ err = syscall.Sendmsg(fds[1], data, nil, nil, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create and run dispatcher once.
+ sink := &fakeNetworkDispatcher{}
+ d, err := test.newDispatcher(fds[0], &endpoint{
+ hdrSize: header.EthernetMinimumSize,
+ dispatcher: sink,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if ok, err := d.dispatch(); !ok || err != nil {
+ t.Fatalf("d.dispatch() = %v, %v", ok, err)
+ }
+
+ // Verify packet.
+ if got, want := len(sink.pkts), 1; got != want {
+ t.Fatalf("len(sink.pkts) = %d, want %d", got, want)
+ }
+ pkt := sink.pkts[0]
+ if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want {
+ t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want)
+ }
+ if got, want := pkt.Data.Size(), 4; got != want {
+ t.Errorf("pkt.Data.Size() = %d, want %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 2dfd29aa9..c475dda20 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -18,6 +18,7 @@ package fdbased
import (
"encoding/binary"
+ "fmt"
"syscall"
"golang.org/x/sys/unix"
@@ -170,10 +171,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
- eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth = header.Ethernet(pkt)
+ eth := header.Ethernet(pkt)
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -190,10 +190,14 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
}
}
- pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{
- Data: buffer.View(pkt).ToVectorisedView(),
- LinkHeader: buffer.View(eth),
+ pbuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.View(pkt).ToVectorisedView(),
})
+ if d.e.hdrSize > 0 {
+ if _, ok := pbuf.LinkHeader().Consume(d.e.hdrSize); !ok {
+ panic(fmt.Sprintf("LinkHeader().Consume(%d) must succeed", d.e.hdrSize))
+ }
+ }
+ d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pbuf)
return true, nil
}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index f04738cfb..8c3ca86d6 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -103,7 +103,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
d.allocateViews(BufConfig)
n, err := rawfile.BlockingReadv(d.fd, d.iovecs)
- if err != nil {
+ if n == 0 || err != nil {
return false, err
}
if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
@@ -111,17 +111,22 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
// isn't used and it isn't in a view.
n -= virtioNetHdrSize
}
- if n <= d.e.hdrSize {
- return false, nil
- }
+
+ used := d.capViews(n, BufConfig)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
+ })
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
- eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize])
+ hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize)
+ if !ok {
+ return false, nil
+ }
+ eth := header.Ethernet(hdr)
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -138,13 +143,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
}
}
- used := d.capViews(n, BufConfig)
- pkt := &stack.PacketBuffer{
- Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
- LinkHeader: buffer.View(eth),
- }
- pkt.Data.TrimFront(d.e.hdrSize)
-
d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
// Prepare e.views for another packet: release used views.
@@ -268,17 +266,22 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
n -= virtioNetHdrSize
}
- if n <= d.e.hdrSize {
- return false, nil
- }
+
+ used := d.capViews(k, int(n), BufConfig)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)),
+ })
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
- eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth = header.Ethernet(d.views[k][0])
+ hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize)
+ if !ok {
+ return false, nil
+ }
+ eth := header.Ethernet(hdr)
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -295,12 +298,6 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
}
}
- used := d.capViews(k, int(n), BufConfig)
- pkt := &stack.PacketBuffer{
- Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)),
- LinkHeader: buffer.View(eth),
- }
- pkt.Data.TrimFront(d.e.hdrSize)
d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
// Prepare e.views for another packet: release used views.
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 568c6874f..38aa694e4 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -77,16 +77,16 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
+ // Construct data as the unparsed portion for the loopback packet.
+ data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
// Because we're immediately turning around and writing the packet back
// to the rx path, we intentionally don't preserve the remote and local
// link addresses from the stack.Route we're passed.
- e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
+ newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: data,
})
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt)
return nil
}
@@ -98,18 +98,25 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
// There should be an ethernet header at the beginning of vv.
- hdr, ok := vv.PullUp(header.EthernetMinimumSize)
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
if !ok {
// Reject the packet if it's shorter than an ethernet header.
return tcpip.ErrBadAddress
}
linkHeader := header.Ethernet(hdr)
- vv.TrimFront(len(linkHeader))
- e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{
- Data: vv,
- LinkHeader: buffer.View(linkHeader),
- })
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt)
return nil
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareLoopback
+}
+
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index 82b441b79..e7493e5c5 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index c69d6b7e9..56a611825 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -18,6 +18,7 @@ package muxed
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -129,6 +130,15 @@ func (m *InjectableEndpoint) Wait() {
}
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("unsupported operation")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 0744f66d6..3e4afcdad 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -46,14 +46,14 @@ func TestInjectableEndpointRawDispatch(t *testing.T) {
func TestInjectableEndpointDispatch(t *testing.T) {
endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
- hdr := buffer.NewPrependable(1)
- hdr.Prepend(1)[0] = 0xFA
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: 1,
+ Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
- })
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt)
buf := make([]byte, 6500)
bytesRead, err := sock.Read(buf)
@@ -67,13 +67,14 @@ func TestInjectableEndpointDispatch(t *testing.T) {
func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
- hdr := buffer.NewPrependable(1)
- hdr.Prepend(1)[0] = 0xFA
- packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- Data: buffer.NewView(0).ToVectorisedView(),
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: 1,
+ Data: buffer.NewView(0).ToVectorisedView(),
})
+ pkt.TransportHeader().Push(1)[0] = 0xFA
+ packetRoute := stack.Route{RemoteAddress: dstIP}
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt)
buf := make([]byte, 6500)
bytesRead, err := sock.Read(buf)
if err != nil {
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD
index bdd5276ad..2cdb23475 100644
--- a/pkg/tcpip/link/nested/BUILD
+++ b/pkg/tcpip/link/nested/BUILD
@@ -12,6 +12,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 2998f9c4f..d40de54df 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -60,6 +61,16 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
}
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.mu.RLock()
+ d := e.dispatcher
+ e.mu.RUnlock()
+ if d != nil {
+ d.DeliverOutboundPacket(remote, local, protocol, pkt)
+ }
+}
+
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.mu.Lock()
@@ -129,3 +140,13 @@ func (e *Endpoint) GSOMaxSize() uint32 {
}
return 0
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.child.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.child.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go
index c1a219f02..c1f9d308c 100644
--- a/pkg/tcpip/link/nested/nested_test.go
+++ b/pkg/tcpip/link/nested/nested_test.go
@@ -55,6 +55,10 @@ func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAd
d.count++
}
+func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNestedLinkEndpoint(t *testing.T) {
const emptyAddress = tcpip.LinkAddress("")
@@ -83,7 +87,7 @@ func TestNestedLinkEndpoint(t *testing.T) {
t.Error("After attach, nestedEP.IsAttached() = false, want = true")
}
- nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{})
+ nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if disp.count != 1 {
t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count)
}
@@ -97,7 +101,7 @@ func TestNestedLinkEndpoint(t *testing.T) {
}
disp.count = 0
- nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{})
+ nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if disp.count != 0 {
t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count)
}
diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD
new file mode 100644
index 000000000..6fff160ce
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "packetsocket",
+ srcs = ["endpoint.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
new file mode 100644
index 000000000..3922c2a04
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packetsocket provides a link layer endpoint that provides the ability
+// to loop outbound packets to any AF_PACKET sockets that may be interested in
+// the outgoing packet.
+package packetsocket
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type endpoint struct {
+ nested.Endpoint
+}
+
+// New creates a new packetsocket LinkEndpoint.
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ e := &endpoint{}
+ e.Endpoint.Init(lower, e)
+ return e
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
+ return e.Endpoint.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ }
+
+ return e.Endpoint.WritePackets(r, gso, pkts, proto)
+}
diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD
index 054c213bc..1d0079bd6 100644
--- a/pkg/tcpip/link/qdisc/fifo/BUILD
+++ b/pkg/tcpip/link/qdisc/fifo/BUILD
@@ -14,6 +14,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index b5dfb7850..fc1e34fc7 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -106,6 +107,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
// Attach implements stack.LinkEndpoint.Attach.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
@@ -193,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3267): Queue these packets as well once
+ // WriteRawPacket takes PacketBuffer instead of VectorisedView.
return e.lower.WriteRawPacket(vv)
}
@@ -207,3 +215,13 @@ func (e *endpoint) Wait() {
e.wg.Wait()
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
+func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.lower.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 14b527bc2..6c410c5a6 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -18,3 +18,14 @@ go_library(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "rawfile_test",
+ srcs = [
+ "errors_test.go",
+ ],
+ library = "rawfile",
+ deps = [
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index 99313ee25..5db4bf12b 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -14,7 +14,7 @@
// +build linux,amd64 linux,arm64
// +build go1.12
-// +build !go1.16
+// +build !go1.17
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
index a0a873c84..604868fd8 100644
--- a/pkg/tcpip/link/rawfile/errors.go
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -31,10 +31,12 @@ var translations [maxErrno]*tcpip.Error
// *tcpip.Error.
//
// Valid, but unrecognized errnos will be translated to
-// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos.
+// tcpip.ErrInvalidEndpointState (EINVAL).
func TranslateErrno(e syscall.Errno) *tcpip.Error {
- if err := translations[e]; err != nil {
- return err
+ if e > 0 && e < syscall.Errno(len(translations)) {
+ if err := translations[e]; err != nil {
+ return err
+ }
}
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go
new file mode 100644
index 000000000..e4cdc66bd
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/errors_test.go
@@ -0,0 +1,53 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package rawfile
+
+import (
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestTranslateErrno(t *testing.T) {
+ for _, test := range []struct {
+ errno syscall.Errno
+ translated *tcpip.Error
+ }{
+ {
+ errno: syscall.Errno(0),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.Errno(maxErrno),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.Errno(514),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.EEXIST,
+ translated: tcpip.ErrDuplicateAddress,
+ },
+ } {
+ got := TranslateErrno(test.errno)
+ if got != test.translated {
+ t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated)
+ }
+ }
+}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 69de6eb3e..f4c32c2da 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -66,38 +66,14 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
return nil
}
-// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a
-// single syscall. It fails if partial data is written.
-func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
- // If there is no second and third buffer, issue a regular write.
- if len(b2) == 0 && len(b3) == 0 {
- return NonBlockingWrite(fd, b1)
- }
-
- // Build the iovec that represents them and issue a writev syscall.
- iovec := [3]syscall.Iovec{
- {
- Base: &b1[0],
- Len: uint64(len(b1)),
- },
- {
- Base: &b2[0],
- Len: uint64(len(b2)),
- },
- }
- iovecLen := uintptr(2)
-
- if len(b3) > 0 {
- iovecLen++
- iovec[2].Base = &b3[0]
- iovec[2].Len = uint64(len(b3))
- }
-
+// NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall.
+// It fails if partial data is written.
+func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error {
+ iovecLen := uintptr(len(iovec))
_, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen)
if e != 0 {
return TranslateErrno(e)
}
-
return nil
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 0374a2441..7fb8a6c49 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -183,27 +183,33 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- // Add the ethernet header here.
- eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
- pkt.LinkHeader = buffer.View(eth)
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ // Add ethernet header if needed.
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+
+ // Preserve the src address if it's set in the route.
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
- v := pkt.Data.ToView()
+ views := pkt.Views()
// Transmit the packet.
e.mu.Lock()
- ok := e.tx.transmit(pkt.Header.View(), v)
+ ok := e.tx.transmit(views...)
e.mu.Unlock()
if !ok {
@@ -220,10 +226,10 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- v := vv.ToView()
+ views := vv.Views()
// Transmit the packet.
e.mu.Lock()
- ok := e.tx.transmit(v, buffer.View{})
+ ok := e.tx.transmit(views...)
e.mu.Unlock()
if !ok {
@@ -269,16 +275,18 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
rxb[i].Size = e.bufferSize
}
- if n < header.EthernetMinimumSize {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.View(b).ToVectorisedView(),
+ })
+
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
continue
}
+ eth := header.Ethernet(hdr)
// Send packet up the stack.
- eth := header.Ethernet(b[:header.EthernetMinimumSize])
- d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{
- Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(),
- LinkHeader: buffer.View(eth),
- })
+ d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt)
}
// Clean state.
@@ -287,3 +295,8 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
e.completed.Done()
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
+func (*endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 28a2e88ba..22d5c97f1 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L
c.packetCh <- struct{}{}
}
+func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (c *testContext) cleanup() {
c.ep.Close()
closeFDs(&c.txCfg)
@@ -262,21 +266,23 @@ func TestSimpleSend(t *testing.T) {
for iters := 1000; iters > 0; iters-- {
func() {
+ hdrLen, dataLen := rand.Intn(10000), rand.Intn(10000)
+
// Prepare and send packet.
- n := rand.Intn(10000)
- hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength()))
- hdrBuf := hdr.Prepend(n)
+ hdrBuf := buffer.NewView(hdrLen)
randomFill(hdrBuf)
- n = rand.Intn(10000)
- buf := buffer.NewView(n)
- randomFill(buf)
+ data := buffer.NewView(dataLen)
+ randomFill(data)
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: hdrLen + int(c.ep.MaxHeaderLength()),
+ Data: data.ToVectorisedView(),
+ })
+ copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -313,7 +319,7 @@ func TestSimpleSend(t *testing.T) {
// Compare contents skipping the ethernet header added by the
// endpoint.
- merged := append(hdrBuf, buf...)
+ merged := append(hdrBuf, data...)
if uint32(len(contents)) < pi.Size {
t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size)
}
@@ -340,14 +346,14 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
LocalLinkAddress: newLocalLinkAddress,
}
- // WritePacket panics given a prependable with anything less than
- // the minimum size of the ethernet header.
- hdr := buffer.NewPrependable(header.EthernetMinimumSize)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ // WritePacket panics given a prependable with anything less than
+ // the minimum size of the ethernet header.
+ ReserveHeaderBytes: header.EthernetMinimumSize,
+ })
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{
- Header: hdr,
- }); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -399,12 +405,12 @@ func TestFillTxQueue(t *testing.T) {
// until the tx queue if full.
ids := make(map[uint64]struct{})
for i := queuePipeSize / 40; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -418,11 +424,11 @@ func TestFillTxQueue(t *testing.T) {
}
// Next attempt to write must fail.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != want {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
}
@@ -446,11 +452,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// Send two packets so that the id slice has at least two slots.
for i := 2; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
}
@@ -469,11 +475,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// until the tx queue if full.
ids := make(map[uint64]struct{})
for i := queuePipeSize / 40; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -487,11 +493,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
}
// Next attempt to write must fail.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != want {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
}
@@ -513,11 +519,11 @@ func TestFillTxMemory(t *testing.T) {
// we fill the memory.
ids := make(map[uint64]struct{})
for i := queueDataSize / bufferSize; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -532,11 +538,11 @@ func TestFillTxMemory(t *testing.T) {
}
// Next attempt to write must fail.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
})
+ err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
if want := tcpip.ErrWouldBlock; err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
@@ -560,11 +566,11 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// Each packet is uses up one buffer, so write as many as possible
// until there is only one buffer left.
for i := queueDataSize/bufferSize - 1; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -575,23 +581,22 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// Attempt to write a two-buffer packet. It must fail.
{
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- uu := buffer.NewView(bufferSize).ToVectorisedView()
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- Data: uu,
- }); err != want {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buffer.NewView(bufferSize).ToVectorisedView(),
+ })
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
}
// Attempt to write the one-buffer packet again. It must succeed.
{
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
}
diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go
index 6b8d7859d..44f421c2d 100644
--- a/pkg/tcpip/link/sharedmem/tx.go
+++ b/pkg/tcpip/link/sharedmem/tx.go
@@ -18,6 +18,7 @@ import (
"math"
"syscall"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
)
@@ -76,9 +77,9 @@ func (t *tx) cleanup() {
syscall.Munmap(t.data)
}
-// transmit sends a packet made up of up to two buffers. Returns a boolean that
-// specifies whether the packet was successfully transmitted.
-func (t *tx) transmit(a, b []byte) bool {
+// transmit sends a packet made of bufs. Returns a boolean that specifies
+// whether the packet was successfully transmitted.
+func (t *tx) transmit(bufs ...buffer.View) bool {
// Pull completions from the tx queue and add their buffers back to the
// pool so that we can reuse them.
for {
@@ -93,7 +94,10 @@ func (t *tx) transmit(a, b []byte) bool {
}
bSize := t.bufs.entrySize
- total := uint32(len(a) + len(b))
+ total := uint32(0)
+ for _, data := range bufs {
+ total += uint32(len(data))
+ }
bufCount := (total + bSize - 1) / bSize
// Allocate enough buffers to hold all the data.
@@ -115,7 +119,7 @@ func (t *tx) transmit(a, b []byte) bool {
// Copy data into allocated buffers.
nBuf := buf
var dBuf []byte
- for _, data := range [][]byte{a, b} {
+ for _, data := range bufs {
for len(data) > 0 {
if len(dBuf) == 0 {
dBuf = t.data[nBuf.Offset:][:nBuf.Size]
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
index 7cbc305e7..4aac12a8c 100644
--- a/pkg/tcpip/link/sniffer/BUILD
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -14,6 +14,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/link/nested",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index d9cd4e83a..560477926 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/link/nested"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -123,13 +124,18 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
logPacket(prefix, protocol, pkt, gso)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
- totalLength := pkt.Header.UsedLength() + pkt.Data.Size()
+ totalLength := pkt.Size()
length := totalLength
if max := int(e.maxPCAPLen); length > max {
length = max
@@ -150,12 +156,11 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw
length -= n
}
}
- write(pkt.Header.View())
- for _, view := range pkt.Data.Views() {
+ for _, v := range pkt.Views() {
if length == 0 {
break
}
- write(view)
+ write(v)
}
}
}
@@ -180,9 +185,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- e.dumpPacket("send", nil, 0, &stack.PacketBuffer{
+ e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: vv,
- })
+ }))
return e.Endpoint.WriteRawPacket(vv)
}
@@ -191,53 +196,52 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
var transProto uint8
src := tcpip.Address("unknown")
dst := tcpip.Address("unknown")
- id := 0
- size := uint16(0)
+ var size uint16
+ var id uint32
var fragmentOffset uint16
var moreFragments bool
- // Create a clone of pkt, including any headers if present. Avoid allocating
- // backing memory for the clone.
- views := [8]buffer.View{}
- vv := buffer.NewVectorisedView(0, views[:0])
- vv.AppendView(pkt.Header.View())
- vv.Append(pkt.Data)
-
+ // Clone the packet buffer to not modify the original.
+ //
+ // We don't clone the original packet buffer so that the new packet buffer
+ // does not have any of its headers set.
+ pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())})
switch protocol {
case header.IPv4ProtocolNumber:
- hdr, ok := vv.PullUp(header.IPv4MinimumSize)
- if !ok {
+ if ok := parse.IPv4(pkt); !ok {
return
}
- ipv4 := header.IPv4(hdr)
+
+ ipv4 := header.IPv4(pkt.NetworkHeader().View())
fragmentOffset = ipv4.FragmentOffset()
moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
src = ipv4.SourceAddress()
dst = ipv4.DestinationAddress()
transProto = ipv4.Protocol()
size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
- vv.TrimFront(int(ipv4.HeaderLength()))
- id = int(ipv4.ID())
+ id = uint32(ipv4.ID())
case header.IPv6ProtocolNumber:
- hdr, ok := vv.PullUp(header.IPv6MinimumSize)
+ proto, fragID, fragOffset, fragMore, ok := parse.IPv6(pkt)
if !ok {
return
}
- ipv6 := header.IPv6(hdr)
+
+ ipv6 := header.IPv6(pkt.NetworkHeader().View())
src = ipv6.SourceAddress()
dst = ipv6.DestinationAddress()
- transProto = ipv6.NextHeader()
+ transProto = uint8(proto)
size = ipv6.PayloadLength()
- vv.TrimFront(header.IPv6MinimumSize)
+ id = fragID
+ moreFragments = fragMore
+ fragmentOffset = fragOffset
case header.ARPProtocolNumber:
- hdr, ok := vv.PullUp(header.ARPSize)
- if !ok {
+ if parse.ARP(pkt) {
return
}
- vv.TrimFront(header.ARPSize)
- arp := header.ARP(hdr)
+
+ arp := header.ARP(pkt.NetworkHeader().View())
log.Infof(
"%s arp %s (%s) -> %s (%s) valid:%t",
prefix,
@@ -259,7 +263,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
switch tcpip.TransportProtocolNumber(transProto) {
case header.ICMPv4ProtocolNumber:
transName = "icmp"
- hdr, ok := vv.PullUp(header.ICMPv4MinimumSize)
+ hdr, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok {
break
}
@@ -296,7 +300,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.ICMPv6ProtocolNumber:
transName = "icmp"
- hdr, ok := vv.PullUp(header.ICMPv6MinimumSize)
+ hdr, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
if !ok {
break
}
@@ -331,11 +335,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.UDPProtocolNumber:
transName = "udp"
- hdr, ok := vv.PullUp(header.UDPMinimumSize)
- if !ok {
+ if ok := parse.UDP(pkt); !ok {
break
}
- udp := header.UDP(hdr)
+
+ udp := header.UDP(pkt.TransportHeader().View())
if fragmentOffset == 0 {
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
@@ -345,19 +349,19 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.TCPProtocolNumber:
transName = "tcp"
- hdr, ok := vv.PullUp(header.TCPMinimumSize)
- if !ok {
+ if ok := parse.TCP(pkt); !ok {
break
}
- tcp := header.TCP(hdr)
+
+ tcp := header.TCP(pkt.TransportHeader().View())
if fragmentOffset == 0 {
offset := int(tcp.DataOffset())
if offset < header.TCPMinimumSize {
details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
break
}
- if offset > vv.Size() && !moreFragments {
- details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size())
+ if size := pkt.Data.Size() + len(tcp); offset > size && !moreFragments {
+ details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size)
break
}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index e0db6cf54..0243424f6 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -1,17 +1,32 @@
load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
+go_template_instance(
+ name = "tun_endpoint_refs",
+ out = "tun_endpoint_refs.go",
+ package = "tun",
+ prefix = "tunEndpoint",
+ template = "//pkg/refs_vfs2:refs_template",
+ types = {
+ "T": "tunEndpoint",
+ },
+)
+
go_library(
name = "tun",
srcs = [
"device.go",
"protocol.go",
+ "tun_endpoint_refs.go",
"tun_unsafe.go",
],
visibility = ["//visibility:public"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/log",
"//pkg/refs",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 6bc9033d0..b6ddbe81e 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -18,7 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -64,14 +64,14 @@ func (d *Device) beforeSave() {
}
// Release implements fs.FileOperations.Release.
-func (d *Device) Release() {
+func (d *Device) Release(ctx context.Context) {
d.mu.Lock()
defer d.mu.Unlock()
// Decrease refcount if there is an endpoint associated with this file.
if d.endpoint != nil {
d.endpoint.RemoveNotify(d.notifyHandle)
- d.endpoint.DecRef()
+ d.endpoint.DecRef(ctx)
d.endpoint = nil
}
}
@@ -134,11 +134,13 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
// 2. Creating a new NIC.
id := tcpip.NICID(s.UniqueID())
+ // TODO(gvisor.dev/1486): enable leak check for tunEndpoint.
endpoint := &tunEndpoint{
Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
stack: s,
nicID: id,
name: name,
+ isTap: prefix == "tap",
}
endpoint.Endpoint.LinkEPCapabilities = linkCaps
if endpoint.name == "" {
@@ -213,12 +215,11 @@ func (d *Device) Write(data []byte) (int64, error) {
remote = tcpip.LinkAddress(zeroMAC[:])
}
- pkt := &stack.PacketBuffer{
- Data: buffer.View(data).ToVectorisedView(),
- }
- if ethHdr != nil {
- pkt.LinkHeader = buffer.View(ethHdr)
- }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: len(ethHdr),
+ Data: buffer.View(data).ToVectorisedView(),
+ })
+ copy(pkt.LinkHeader().Push(len(ethHdr)), ethHdr)
endpoint.InjectLinkAddr(protocol, remote, pkt)
return dataLen, nil
}
@@ -263,33 +264,22 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
// If the packet does not already have link layer header, and the route
// does not exist, we can't compute it. This is possibly a raw packet, tun
// device doesn't support this at the moment.
- if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" {
+ if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" {
return nil, false
}
// Ethernet header (TAP only).
if d.hasFlags(linux.IFF_TAP) {
// Add ethernet header if not provided.
- if info.Pkt.LinkHeader == nil {
- hdr := &header.EthernetFields{
- SrcAddr: info.Route.LocalLinkAddress,
- DstAddr: info.Route.RemoteLinkAddress,
- Type: info.Proto,
- }
- if hdr.SrcAddr == "" {
- hdr.SrcAddr = d.endpoint.LinkAddress()
- }
-
- eth := make(header.Ethernet, header.EthernetMinimumSize)
- eth.Encode(hdr)
- vv.AppendView(buffer.View(eth))
- } else {
- vv.AppendView(info.Pkt.LinkHeader)
+ if info.Pkt.LinkHeader().View().IsEmpty() {
+ d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
}
+ vv.AppendView(info.Pkt.LinkHeader().View())
}
// Append upper headers.
- vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):]))
+ vv.AppendView(info.Pkt.NetworkHeader().View())
+ vv.AppendView(info.Pkt.TransportHeader().View())
// Append data payload.
vv.Append(info.Pkt.Data)
@@ -341,18 +331,52 @@ func (d *Device) WriteNotify() {
// It is ref-counted as multiple opening files can attach to the same NIC.
// The last owner is responsible for deleting the NIC.
type tunEndpoint struct {
+ tunEndpointRefs
*channel.Endpoint
- refs.AtomicRefCount
-
stack *stack.Stack
nicID tcpip.NICID
name string
+ isTap bool
}
-// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
-func (e *tunEndpoint) DecRef() {
- e.DecRefWithDestructor(func() {
+// DecRef decrements refcount of e, removing NIC if it reaches 0.
+func (e *tunEndpoint) DecRef(ctx context.Context) {
+ e.tunEndpointRefs.DecRef(func() {
e.stack.RemoveNIC(e.nicID)
})
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.isTap {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.isTap {
+ return
+ }
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
+ hdr := &header.EthernetFields{
+ SrcAddr: local,
+ DstAddr: remote,
+ Type: protocol,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = e.LinkAddress()
+ }
+
+ eth.Encode(hdr)
+}
+
+// MaxHeaderLength returns the maximum size of the link layer header.
+func (e *tunEndpoint) MaxHeaderLength() uint16 {
+ if e.isTap {
+ return header.EthernetMinimumSize
+ }
+ return 0
+}
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 0956d2c65..ee84c3d96 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -12,6 +12,7 @@ go_library(
"//pkg/gate",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
@@ -25,6 +26,7 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 949b3f2b2..b152a0f26 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/gate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -59,6 +60,15 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatchGate.Leave()
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+ e.dispatchGate.Leave()
+}
+
// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
// registers with the lower endpoint as its dispatcher so that "e" is called
// for inbound packets.
@@ -147,3 +157,13 @@ func (e *Endpoint) WaitDispatch() {
// Wait implements stack.LinkEndpoint.Wait.
func (e *Endpoint) Wait() {}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.lower.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 63bf40562..94827fc56 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -39,6 +40,10 @@ func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress,
e.dispatchCount++
}
+func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.attachCount++
e.dispatcher = dispatcher
@@ -81,29 +86,39 @@ func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
return nil
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("unimplemented")
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*countedEndpoint) Wait() {}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
wep := New(ep)
// Write and check that it goes through.
- wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 1; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on dispatches, then try to write. It must go through.
wep.WaitDispatch()
- wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on writes, then try to write. It must not go through.
wep.WaitWrite()
- wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{})
+ wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
@@ -120,21 +135,21 @@ func TestWaitDispatch(t *testing.T) {
}
// Dispatch and check that it goes through.
- ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 1; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on writes, then try to dispatch. It must go through.
wep.WaitWrite()
- ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on dispatches, then try to dispatch. It must not go through.
wep.WaitDispatch()
- ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 6a4839fb8..59710352b 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -9,13 +9,16 @@ go_test(
"ip_test.go",
],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
],
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index eddf7b725..b40dde96b 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/stack",
],
)
@@ -28,5 +29,6 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7f27a840d..b47a7be51 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -15,20 +15,15 @@
// Package arp implements the ARP network protocol. It is used to resolve
// IPv4 addresses into link-local MAC addresses, and advertises IPv4
// addresses of its stack with the local network.
-//
-// To use it in the networking stack, pass arp.NewProtocol() as one of the
-// network protocols when calling stack.New. Then add an "arp" address to every
-// NIC on the stack that should respond to ARP requests. That is:
-//
-// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
-// // handle err
-// }
package arp
import (
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -40,45 +35,74 @@ const (
ProtocolAddress = tcpip.Address("arp")
)
-// endpoint implements stack.NetworkEndpoint.
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- protocol *protocol
- nicID tcpip.NICID
+ stack.AddressableEndpointState
+
+ protocol *protocol
+
+ // enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ nic stack.NetworkInterface
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
+ nud stack.NUDHandler
}
-// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
-func (e *endpoint) DefaultTTL() uint8 {
- return 0
+func (e *endpoint) Enable() *tcpip.Error {
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ e.setEnabled(true)
+ return nil
}
-func (e *endpoint) MTU() uint32 {
- lmtu := e.linkEP.MTU()
- return lmtu - uint32(e.MaxHeaderLength())
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
}
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
}
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
+// setEnabled sets the enabled status for the endpoint.
+func (e *endpoint) setEnabled(v bool) {
+ if v {
+ atomic.StoreUint32(&e.enabled, 1)
+ } else {
+ atomic.StoreUint32(&e.enabled, 0)
+ }
}
-func (e *endpoint) ID() *stack.NetworkEndpointID {
- return &stack.NetworkEndpointID{ProtocolAddress}
+func (e *endpoint) Disable() {
+ e.setEnabled(false)
}
-func (e *endpoint) PrefixLen() int {
+// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
+func (e *endpoint) DefaultTTL() uint8 {
return 0
}
+func (e *endpoint) MTU() uint32 {
+ lmtu := e.linkEP.MTU()
+ return lmtu - uint32(e.MaxHeaderLength())
+}
+
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.AddressableEndpointState.Cleanup()
+}
func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
@@ -86,7 +110,7 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return e.protocol.Number()
+ return ProtocolNumber
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
@@ -99,7 +123,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
}
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
- h := header.ARP(pkt.NetworkHeader)
+ if !e.isEnabled() {
+ return
+ }
+
+ h := header.ARP(pkt.NetworkHeader().View())
if !h.IsValid() {
return
}
@@ -107,25 +135,58 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
switch h.Op() {
case header.ARPRequest:
localAddr := tcpip.Address(h.ProtocolAddressTarget())
- if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
- return // we have no useful answer, ignore the request
+
+ if e.nud == nil {
+ if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
+ } else {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+
+ remoteAddr := tcpip.Address(h.ProtocolAddressSender())
+ remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
- hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
- packet := header.ARP(hdr.Prepend(header.ARPSize))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize,
+ })
+ packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
packet.SetIPv4OverEthernet()
packet.SetOp(header.ARPReply)
copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- })
- fallthrough // also fill the cache from requests
+ _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+
+ if e.nud == nil {
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
+ return
+ }
+
+ // The solicited, override, and isRouter flags are not available for ARP;
+ // they are only available for IPv6 Neighbor Advertisements.
+ e.nud.HandleConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{
+ // Solicited and unsolicited (also referred to as gratuitous) ARP Replies
+ // are handled equivalently to a solicited Neighbor Advertisement.
+ Solicited: true,
+ // If a different link address is received than the one cached, the entry
+ // should always go to Stale.
+ Override: false,
+ // ARP does not distinguish between router and non-router hosts.
+ IsRouter: false,
+ })
}
}
@@ -142,16 +203,16 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
- if addrWithPrefix.Address != ProtocolAddress {
- return nil, tcpip.ErrBadLocalAddress
- }
- return &endpoint{
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
protocol: p,
- nicID: nicID,
- linkEP: sender,
+ nic: nic,
+ linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
- }, nil
+ nud: nud,
+ }
+ e.AddressableEndpointState.Init(e)
+ return e
}
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
@@ -160,28 +221,31 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
- RemoteLinkAddress: broadcastMAC,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ if len(r.RemoteLinkAddress) == 0 {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
- hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
- h := header.ARP(hdr.Prepend(header.ARPSize))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize,
+ })
+ h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
h.SetIPv4OverEthernet()
h.SetOp(header.ARPRequest)
copy(h.HardwareAddressSender(), linkEP.LinkAddress())
copy(h.ProtocolAddressSender(), localAddr)
copy(h.ProtocolAddressTarget(), addr)
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- })
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
- return broadcastMAC, true
+ return header.EthernetBroadcastAddress, true
}
if header.IsV4MulticastAddress(addr) {
return header.EthernetAddressFromMulticastIPv4Address(addr), true
@@ -190,12 +254,12 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
// SetOption implements stack.NetworkProtocol.SetOption.
-func (*protocol) SetOption(option interface{}) *tcpip.Error {
+func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements stack.NetworkProtocol.Option.
-func (*protocol) Option(option interface{}) *tcpip.Error {
+func (*protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
@@ -207,18 +271,14 @@ func (*protocol) Wait() {}
// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
- hdr, ok := pkt.Data.PullUp(header.ARPSize)
- if !ok {
- return 0, false, false
- }
- pkt.NetworkHeader = hdr
- pkt.Data.TrimFront(header.ARPSize)
- return 0, false, true
+ return 0, false, parse.ARP(pkt)
}
-var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-
// NewProtocol returns an ARP network protocol.
-func NewProtocol() stack.NetworkProtocol {
+//
+// Note, to make sure that the ARP endpoint receives ARP packets, the "arp"
+// address must be added to every NIC that should respond to ARP requests. See
+// ProtocolAddress for more details.
+func NewProtocol(*stack.Stack) stack.NetworkProtocol {
return &protocol{}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 66e67429c..626af975a 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -16,10 +16,12 @@ package arp_test
import (
"context"
+ "fmt"
"strconv"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -32,54 +34,192 @@ import (
)
const (
+ nicID = 1
+
+ stackAddr = tcpip.Address("\x0a\x00\x00\x01")
stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
- stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
- stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
- stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+
+ remoteAddr = tcpip.Address("\x0a\x00\x00\x02")
+ remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
+
+ unknownAddr = tcpip.Address("\x0a\x00\x00\x03")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
+
+ // eventChanSize defines the size of event channels used by the neighbor
+ // cache's event dispatcher. The size chosen here needs to be sufficient to
+ // queue all the events received during tests before consumption.
+ // If eventChanSize is too small, the tests may deadlock.
+ eventChanSize = 32
+)
+
+type eventType uint8
+
+const (
+ entryAdded eventType = iota
+ entryChanged
+ entryRemoved
)
+func (t eventType) String() string {
+ switch t {
+ case entryAdded:
+ return "add"
+ case entryChanged:
+ return "change"
+ case entryRemoved:
+ return "remove"
+ default:
+ return fmt.Sprintf("unknown (%d)", t)
+ }
+}
+
+type eventInfo struct {
+ eventType eventType
+ nicID tcpip.NICID
+ addr tcpip.Address
+ linkAddr tcpip.LinkAddress
+ state stack.NeighborState
+}
+
+func (e eventInfo) String() string {
+ return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state)
+}
+
+// arpDispatcher implements NUDDispatcher to validate the dispatching of
+// events upon certain NUD state machine events.
+type arpDispatcher struct {
+ // C is where events are queued
+ C chan eventInfo
+}
+
+var _ stack.NUDDispatcher = (*arpDispatcher)(nil)
+
+func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+ e := eventInfo{
+ eventType: entryAdded,
+ nicID: nicID,
+ addr: addr,
+ linkAddr: linkAddr,
+ state: state,
+ }
+ d.C <- e
+}
+
+func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+ e := eventInfo{
+ eventType: entryChanged,
+ nicID: nicID,
+ addr: addr,
+ linkAddr: linkAddr,
+ state: state,
+ }
+ d.C <- e
+}
+
+func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+ e := eventInfo{
+ eventType: entryRemoved,
+ nicID: nicID,
+ addr: addr,
+ linkAddr: linkAddr,
+ state: state,
+ }
+ d.C <- e
+}
+
+func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
+ select {
+ case got := <-d.C:
+ if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" {
+ return fmt.Errorf("got invalid event (-got +want):\n%s", diff)
+ }
+ case <-ctx.Done():
+ return fmt.Errorf("%s for %s", ctx.Err(), want)
+ }
+ return nil
+}
+
+func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error {
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ return d.waitForEvent(ctx, want)
+}
+
+func (d *arpDispatcher) nextEvent() (eventInfo, bool) {
+ select {
+ case event := <-d.C:
+ return event, true
+ default:
+ return eventInfo{}, false
+ }
+}
+
type testContext struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
+ s *stack.Stack
+ linkEP *channel.Endpoint
+ nudDisp *arpDispatcher
}
-func newTestContext(t *testing.T) *testContext {
+func newTestContext(t *testing.T, useNeighborCache bool) *testContext {
+ c := stack.DefaultNUDConfigurations()
+ // Transition from Reachable to Stale almost immediately to test if receiving
+ // probes refreshes positive reachability.
+ c.BaseReachableTime = time.Microsecond
+
+ d := arpDispatcher{
+ // Create an event channel large enough so the neighbor cache doesn't block
+ // while dispatching events. Blocking could interfere with the timing of
+ // NUD transitions.
+ C: make(chan eventInfo, eventChanSize),
+ }
+
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ NUDConfigs: c,
+ NUDDisp: &d,
+ UseNeighborCache: useNeighborCache,
})
- const defaultMTU = 65536
- ep := channel.New(256, defaultMTU, stackLinkAddr)
+ ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, wep); err != nil {
+ if err := s.CreateNIC(nicID, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
t.Fatalf("AddAddress for ipv4 failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ if !useNeighborCache {
+ // The remote address needs to be assigned to the NIC so we can receive and
+ // verify outgoing ARP packets. The neighbor cache isn't concerned with
+ // this; the tests that use linkAddrCache expect the ARP responses to be
+ // received by the same NIC.
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, remoteAddr); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ }
}
- if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
t.Fatalf("AddAddress for arp failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
- NIC: 1,
+ NIC: nicID,
}})
return &testContext{
- t: t,
- s: s,
- linkEP: ep,
+ s: s,
+ linkEP: ep,
+ nudDisp: &d,
}
}
@@ -88,7 +228,7 @@ func (c *testContext) cleanup() {
}
func TestDirectRequest(t *testing.T) {
- c := newTestContext(t)
+ c := newTestContext(t, false /* useNeighborCache */)
defer c.cleanup()
const senderMAC = "\x01\x02\x03\x04\x05\x06"
@@ -103,21 +243,21 @@ func TestDirectRequest(t *testing.T) {
inject := func(addr tcpip.Address) {
copy(h.ProtocolAddressTarget(), addr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{
+ c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: v.ToVectorisedView(),
- })
+ }))
}
- for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
+ for i, address := range []tcpip.Address{stackAddr, remoteAddr} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
inject(address)
pi, _ := c.linkEP.ReadContext(context.Background())
if pi.Proto != arp.ProtocolNumber {
t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
}
- rep := header.ARP(pi.Pkt.Header.View())
+ rep := header.ARP(pi.Pkt.NetworkHeader().View())
if !rep.IsValid() {
- t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength())
+ t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
}
if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
@@ -134,7 +274,7 @@ func TestDirectRequest(t *testing.T) {
})
}
- inject(stackAddrBad)
+ inject(unknownAddr)
// Sleep tests are gross, but this will only potentially flake
// if there's a bug. If there is no bug this will reliably
// succeed.
@@ -144,3 +284,182 @@ func TestDirectRequest(t *testing.T) {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
}
}
+
+func TestDirectRequestWithNeighborCache(t *testing.T) {
+ c := newTestContext(t, true /* useNeighborCache */)
+ defer c.cleanup()
+
+ tests := []struct {
+ name string
+ senderAddr tcpip.Address
+ senderLinkAddr tcpip.LinkAddress
+ targetAddr tcpip.Address
+ isValid bool
+ }{
+ {
+ name: "Loopback",
+ senderAddr: stackAddr,
+ senderLinkAddr: stackLinkAddr,
+ targetAddr: stackAddr,
+ isValid: true,
+ },
+ {
+ name: "Remote",
+ senderAddr: remoteAddr,
+ senderLinkAddr: remoteLinkAddr,
+ targetAddr: stackAddr,
+ isValid: true,
+ },
+ {
+ name: "RemoteInvalidTarget",
+ senderAddr: remoteAddr,
+ senderLinkAddr: remoteLinkAddr,
+ targetAddr: unknownAddr,
+ isValid: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Inject an incoming ARP request.
+ v := make(buffer.View, header.ARPSize)
+ h := header.ARP(v)
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), test.senderLinkAddr)
+ copy(h.ProtocolAddressSender(), test.senderAddr)
+ copy(h.ProtocolAddressTarget(), test.targetAddr)
+ c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{
+ Data: v.ToVectorisedView(),
+ })
+
+ if !test.isValid {
+ // No packets should be sent after receiving an invalid ARP request.
+ // There is no need to perform a blocking read here, since packets are
+ // sent in the same function that handles ARP requests.
+ if pkt, ok := c.linkEP.Read(); ok {
+ t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto)
+ }
+ return
+ }
+
+ // Verify an ARP response was sent.
+ pi, ok := c.linkEP.Read()
+ if !ok {
+ t.Fatal("expected ARP response to be sent, got none")
+ }
+
+ if pi.Proto != arp.ProtocolNumber {
+ t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
+ }
+ rep := header.ARP(pi.Pkt.NetworkHeader().View())
+ if !rep.IsValid() {
+ t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ t.Errorf("got HardwareAddressSender() = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
+ t.Errorf("got ProtocolAddressSender() = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want {
+ t.Errorf("got HardwareAddressTarget() = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
+ t.Errorf("got ProtocolAddressTarget() = %s, want = %s", got, want)
+ }
+
+ // Verify the sender was saved in the neighbor cache.
+ wantEvent := eventInfo{
+ eventType: entryAdded,
+ nicID: nicID,
+ addr: test.senderAddr,
+ linkAddr: tcpip.LinkAddress(test.senderLinkAddr),
+ state: stack.Stale,
+ }
+ if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil {
+ t.Fatal(err)
+ }
+
+ neighbors, err := c.s.Neighbors(nicID)
+ if err != nil {
+ t.Fatalf("c.s.Neighbors(%d): %s", nicID, err)
+ }
+
+ neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
+ for _, n := range neighbors {
+ if existing, ok := neighborByAddr[n.Addr]; ok {
+ if diff := cmp.Diff(existing, n); diff != "" {
+ t.Fatalf("duplicate neighbor entry found (-existing +got):\n%s", diff)
+ }
+ t.Fatalf("exact neighbor entry duplicate found for addr=%s", n.Addr)
+ }
+ neighborByAddr[n.Addr] = n
+ }
+
+ neigh, ok := neighborByAddr[test.senderAddr]
+ if !ok {
+ t.Fatalf("expected neighbor entry with Addr = %s", test.senderAddr)
+ }
+ if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want {
+ t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want)
+ }
+ if got, want := neigh.LocalAddr, stackAddr; got != want {
+ t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want)
+ }
+ if got, want := neigh.State, stack.Stale; got != want {
+ t.Errorf("got neighbor State = %s, want = %s", got, want)
+ }
+
+ // No more events should be dispatched
+ for {
+ event, ok := c.nudDisp.nextEvent()
+ if !ok {
+ break
+ }
+ t.Errorf("unexpected %s", event)
+ }
+ })
+ }
+}
+
+func TestLinkAddressRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ remoteLinkAddr: remoteLinkAddr,
+ expectLinkAddr: remoteLinkAddr,
+ },
+ {
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectLinkAddr: header.EthernetBroadcastAddress,
+ },
+ }
+
+ for _, test := range tests {
+ p := arp.NewProtocol(nil)
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err)
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index d1c728ccf..e247f06a4 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -41,5 +41,8 @@ go_test(
"reassembler_test.go",
],
library = ":fragmentation",
- deps = ["//pkg/tcpip/buffer"],
+ deps = [
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
+ ],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 2982450f8..e1909fab0 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -17,28 +17,58 @@
package fragmentation
import (
+ "errors"
"fmt"
"log"
"time"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
-// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
-const DefaultReassembleTimeout = 30 * time.Second
+const (
+ // DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
+ DefaultReassembleTimeout = 30 * time.Second
-// HighFragThreshold is the threshold at which we start trimming old
-// fragmented packets. Linux uses a default value of 4 MB. See
-// net.ipv4.ipfrag_high_thresh for more information.
-const HighFragThreshold = 4 << 20 // 4MB
+ // HighFragThreshold is the threshold at which we start trimming old
+ // fragmented packets. Linux uses a default value of 4 MB. See
+ // net.ipv4.ipfrag_high_thresh for more information.
+ HighFragThreshold = 4 << 20 // 4MB
-// LowFragThreshold is the threshold we reach to when we start dropping
-// older fragmented packets. It's important that we keep enough room for newer
-// packets to be re-assembled. Hence, this needs to be lower than
-// HighFragThreshold enough. Linux uses a default value of 3 MB. See
-// net.ipv4.ipfrag_low_thresh for more information.
-const LowFragThreshold = 3 << 20 // 3MB
+ // LowFragThreshold is the threshold we reach to when we start dropping
+ // older fragmented packets. It's important that we keep enough room for newer
+ // packets to be re-assembled. Hence, this needs to be lower than
+ // HighFragThreshold enough. Linux uses a default value of 3 MB. See
+ // net.ipv4.ipfrag_low_thresh for more information.
+ LowFragThreshold = 3 << 20 // 3MB
+
+ // minBlockSize is the minimum block size for fragments.
+ minBlockSize = 1
+)
+
+var (
+ // ErrInvalidArgs indicates to the caller that that an invalid argument was
+ // provided.
+ ErrInvalidArgs = errors.New("invalid args")
+)
+
+// FragmentID is the identifier for a fragment.
+type FragmentID struct {
+ // Source is the source address of the fragment.
+ Source tcpip.Address
+
+ // Destination is the destination address of the fragment.
+ Destination tcpip.Address
+
+ // ID is the identification value of the fragment.
+ //
+ // This is a uint32 because IPv6 uses a 32-bit identification value.
+ ID uint32
+
+ // The protocol for the packet.
+ Protocol uint8
+}
// Fragmentation is the main structure that other modules
// of the stack should use to implement IP Fragmentation.
@@ -46,14 +76,19 @@ type Fragmentation struct {
mu sync.Mutex
highLimit int
lowLimit int
- reassemblers map[uint32]*reassembler
+ reassemblers map[FragmentID]*reassembler
rList reassemblerList
size int
timeout time.Duration
+ blockSize uint16
+ clock tcpip.Clock
+ releaseJob *tcpip.Job
}
// NewFragmentation creates a new Fragmentation.
//
+// blockSize specifies the fragment block size, in bytes.
+//
// highMemoryLimit specifies the limit on the memory consumed
// by the fragments stored by Fragmentation (overhead of internal data-structures
// is not accounted). Fragments are dropped when the limit is reached.
@@ -64,7 +99,7 @@ type Fragmentation struct {
// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new a packet with an
// already existing fragmentation-id arrives after the timeout.
-func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation {
if lowMemoryLimit >= highMemoryLimit {
lowMemoryLimit = highMemoryLimit
}
@@ -73,39 +108,81 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t
lowMemoryLimit = 0
}
- return &Fragmentation{
- reassemblers: make(map[uint32]*reassembler),
+ if blockSize < minBlockSize {
+ blockSize = minBlockSize
+ }
+
+ f := &Fragmentation{
+ reassemblers: make(map[FragmentID]*reassembler),
highLimit: highMemoryLimit,
lowLimit: lowMemoryLimit,
timeout: reassemblingTimeout,
+ blockSize: blockSize,
+ clock: clock,
}
+ f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked)
+
+ return f
}
// Process processes an incoming fragment belonging to an ID and returns a
-// complete packet when all the packets belonging to that ID have been received.
-func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
+// complete packet and its protocol number when all the packets belonging to
+// that ID have been received.
+//
+// [first, last] is the range of the fragment bytes.
+//
+// first must be a multiple of the block size f is configured with. The size
+// of the fragment data must be a multiple of the block size, unless there are
+// no fragments following this fragment (more set to false).
+//
+// proto is the protocol number marked in the fragment being processed. It has
+// to be given here outside of the FragmentID struct because IPv6 should not use
+// the protocol to identify a fragment.
+func (f *Fragmentation) Process(
+ id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (
+ buffer.VectorisedView, uint8, bool, error) {
+ if first > last {
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
+ }
+
+ if first%f.blockSize != 0 {
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
+ }
+
+ fragmentSize := last - first + 1
+ if more && fragmentSize%f.blockSize != 0 {
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
+ }
+
+ if l := vv.Size(); l < int(fragmentSize) {
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
+ }
+ vv.CapLength(int(fragmentSize))
+
f.mu.Lock()
r, ok := f.reassemblers[id]
- if ok && r.tooOld(f.timeout) {
- // This is very likely to be an id-collision or someone performing a slow-rate attack.
- f.release(r)
- ok = false
- }
if !ok {
- r = newReassembler(id)
+ r = newReassembler(id, f.clock)
f.reassemblers[id] = r
+ wasEmpty := f.rList.Empty()
f.rList.PushFront(r)
+ if wasEmpty {
+ // If we have just pushed a first reassembler into an empty list, we
+ // should kickstart the release job. The release job will keep
+ // rescheduling itself until the list becomes empty.
+ f.releaseReassemblersLocked()
+ }
}
f.mu.Unlock()
- res, done, consumed, err := r.process(first, last, more, vv)
+ res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv)
if err != nil {
// We probably got an invalid sequence of fragments. Just
// discard the reassembler and move on.
f.mu.Lock()
f.release(r)
f.mu.Unlock()
- return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err)
+ return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
}
f.mu.Lock()
f.size += consumed
@@ -124,7 +201,7 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
}
}
f.mu.Unlock()
- return res, done, nil
+ return res, firstFragmentProto, done, nil
}
func (f *Fragmentation) release(r *reassembler) {
@@ -142,3 +219,27 @@ func (f *Fragmentation) release(r *reassembler) {
f.size = 0
}
}
+
+// releaseReassemblersLocked releases already-expired reassemblers, then
+// schedules the job to call back itself for the remaining reassemblers if
+// any. This function must be called with f.mu locked.
+func (f *Fragmentation) releaseReassemblersLocked() {
+ now := f.clock.NowMonotonic()
+ for {
+ // The reassembler at the end of the list is the oldest.
+ r := f.rList.Back()
+ if r == nil {
+ // The list is empty.
+ break
+ }
+ elapsed := time.Duration(now-r.creationTime) * time.Nanosecond
+ if f.timeout > elapsed {
+ // If the oldest reassembler has not expired, schedule the release
+ // job so that this function is called back when it has expired.
+ f.releaseJob.Schedule(f.timeout - elapsed)
+ break
+ }
+ // If the oldest reassembler has already expired, release it.
+ f.release(r)
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 72c0f53be..189b223c5 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -15,11 +15,13 @@
package fragmentation
import (
+ "errors"
"reflect"
"testing"
"time"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
// vv is a helper to build VectorisedView from different strings.
@@ -33,16 +35,18 @@ func vv(size int, pieces ...string) buffer.VectorisedView {
}
type processInput struct {
- id uint32
+ id FragmentID
first uint16
last uint16
more bool
+ proto uint8
vv buffer.VectorisedView
}
type processOutput struct {
- vv buffer.VectorisedView
- done bool
+ vv buffer.VectorisedView
+ proto uint8
+ done bool
}
var processTestCases = []struct {
@@ -53,8 +57,8 @@ var processTestCases = []struct {
{
comment: "One ID",
in: []processInput{
- {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -62,12 +66,23 @@ var processTestCases = []struct {
},
},
{
+ comment: "Next Header protocol mismatch",
+ in: []processInput{
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")},
+ },
+ out: []processOutput{
+ {vv: buffer.VectorisedView{}, done: false},
+ {vv: vv(4, "01", "23"), proto: 6, done: true},
+ },
+ },
+ {
comment: "Two IDs",
in: []processInput{
- {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")},
- {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")},
- {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")},
+ {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -81,19 +96,27 @@ var processTestCases = []struct {
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout, &faketime.NullClock{})
+ firstFragmentProto := c.in[0].proto
for i, in := range c.in {
- vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv)
if err != nil {
- t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err)
+ t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s",
+ in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err)
}
if !reflect.DeepEqual(vv, c.out[i].vv) {
- t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv)
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)",
+ in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView())
}
if done != c.out[i].done {
- t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done)
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)",
+ in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done)
}
if c.out[i].done {
+ if firstFragmentProto != proto {
+ t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)",
+ in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto)
+ }
if _, ok := f.reassemblers[in.id]; ok {
t.Errorf("Process(%d) did not remove buffer from reassemblers", i)
}
@@ -109,53 +132,154 @@ func TestFragmentationProcess(t *testing.T) {
}
func TestReassemblingTimeout(t *testing.T) {
- timeout := time.Millisecond
- f := NewFragmentation(1024, 512, timeout)
- // Send first fragment with id = 0, first = 0, last = 0, and more = true.
- f.Process(0, 0, 0, true, vv(1, "0"))
- // Sleep more than the timeout.
- time.Sleep(2 * timeout)
- // Send another fragment that completes a packet.
- // However, no packet should be reassembled because the fragment arrived after the timeout.
- _, done, err := f.Process(0, 1, 1, false, vv(1, "1"))
- if err != nil {
- t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err)
+ const (
+ reassemblyTimeout = time.Millisecond
+ protocol = 0xff
+ )
+
+ type fragment struct {
+ first uint16
+ last uint16
+ more bool
+ data string
+ }
+
+ type event struct {
+ // name is a nickname of this event.
+ name string
+
+ // clockAdvance is a duration to advance the clock. The clock advances
+ // before a fragment specified in the fragment field is processed.
+ clockAdvance time.Duration
+
+ // fragment is a fragment to process. This can be nil if there is no
+ // fragment to process.
+ fragment *fragment
+
+ // expectDone is true if the fragmentation instance should report the
+ // reassembly is done after the fragment is processd.
+ expectDone bool
+
+ // sizeAfterEvent is the expected size of the fragmentation instance after
+ // the event.
+ sizeAfterEvent int
}
- if done {
- t.Errorf("Fragmentation does not respect the reassembling timeout.")
+
+ half1 := &fragment{first: 0, last: 0, more: true, data: "0"}
+ half2 := &fragment{first: 1, last: 1, more: false, data: "1"}
+
+ tests := []struct {
+ name string
+ events []event
+ }{
+ {
+ name: "half1 and half2 are reassembled successfully",
+ events: []event{
+ {
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2",
+ fragment: half2,
+ expectDone: true,
+ sizeAfterEvent: 0,
+ },
+ },
+ },
+ {
+ name: "half1 timeout, half2 timeout",
+ events: []event{
+ {
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half1 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half1 reassembly timeout",
+ clockAdvance: 1,
+ sizeAfterEvent: 0,
+ },
+ {
+ name: "half2",
+ fragment: half2,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2 reassembly timeout",
+ clockAdvance: 1,
+ sizeAfterEvent: 0,
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock)
+ for _, event := range test.events {
+ clock.Advance(event.clockAdvance)
+ if frag := event.fragment; frag != nil {
+ _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data))
+ if err != nil {
+ t.Fatalf("%s: f.Process failed: %s", event.name, err)
+ }
+ if done != event.expectDone {
+ t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone)
+ }
+ }
+ if got, want := f.size, event.sizeAfterEvent; got != want {
+ t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want)
+ }
+ }
+ })
}
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(3, 1, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"))
// Send first fragment with id = 1.
- f.Process(1, 0, 0, true, vv(1, "1"))
+ f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"))
// Send first fragment with id = 2.
- f.Process(2, 0, 0, true, vv(1, "2"))
+ f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"))
// Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
// evicted.
- f.Process(3, 0, 0, true, vv(1, "3"))
+ f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"))
- if _, ok := f.reassemblers[0]; ok {
+ if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
}
- if _, ok := f.reassemblers[1]; ok {
+ if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok {
t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
}
- if _, ok := f.reassemblers[3]; !ok {
+ if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok {
t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
}
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(1, 0, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
// Send the same packet again.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
got := f.size
want := 1
@@ -163,3 +287,97 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
}
}
+
+func TestErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ blockSize uint16
+ first uint16
+ last uint16
+ more bool
+ data string
+ err error
+ }{
+ {
+ name: "exact block size without more",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: false,
+ data: "01",
+ },
+ {
+ name: "exact block size with more",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "01",
+ },
+ {
+ name: "exact block size with more and extra data",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "012",
+ },
+ {
+ name: "exact block size with more and too little data",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "0",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "not exact block size with more",
+ blockSize: 2,
+ first: 2,
+ last: 2,
+ more: true,
+ data: "0",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "not exact block size without more",
+ blockSize: 2,
+ first: 2,
+ last: 2,
+ more: false,
+ data: "0",
+ },
+ {
+ name: "first not a multiple of block size",
+ blockSize: 2,
+ first: 3,
+ last: 4,
+ more: true,
+ data: "01",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "first more than last",
+ blockSize: 2,
+ first: 4,
+ last: 3,
+ more: true,
+ data: "01",
+ err: ErrInvalidArgs,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout, &faketime.NullClock{})
+ _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data))
+ if !errors.Is(err, test.err) {
+ t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
+ }
+ if done {
+ t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 0a83d81f2..9bb051a30 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -18,9 +18,9 @@ import (
"container/heap"
"fmt"
"math"
- "time"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -32,23 +32,23 @@ type hole struct {
type reassembler struct {
reassemblerEntry
- id uint32
+ id FragmentID
size int
+ proto uint8
mu sync.Mutex
holes []hole
deleted int
heap fragHeap
done bool
- creationTime time.Time
+ creationTime int64
}
-func newReassembler(id uint32) *reassembler {
+func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
r := &reassembler{
id: id,
holes: make([]hole, 0, 16),
- deleted: 0,
heap: make(fragHeap, 0, 8),
- creationTime: time.Now(),
+ creationTime: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
first: 0,
@@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
return used
}
-func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) {
+func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
consumed := 0
@@ -86,7 +86,18 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, false, consumed, nil
+ return buffer.VectorisedView{}, 0, false, consumed, nil
+ }
+ // For IPv6, it is possible to have different Protocol values between
+ // fragments of a packet (because, unlike IPv4, the Protocol is not used to
+ // identify a fragment). In this case, only the Protocol of the first
+ // fragment must be used as per RFC 8200 Section 4.5.
+ //
+ // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded
+ // here (instead of just the protocol) because most IP options should be
+ // derived from the first fragment.
+ if first == 0 {
+ r.proto = proto
}
if r.updateHoles(first, last, more) {
// We store the incoming packet only if it filled some holes.
@@ -96,17 +107,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
}
// Check if all the holes have been deleted and we are ready to reassamble.
if r.deleted < len(r.holes) {
- return buffer.VectorisedView{}, false, consumed, nil
+ return buffer.VectorisedView{}, 0, false, consumed, nil
}
res, err := r.heap.reassemble()
if err != nil {
- return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err)
+ return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err)
}
- return res, true, consumed, nil
-}
-
-func (r *reassembler) tooOld(timeout time.Duration) bool {
- return time.Now().Sub(r.creationTime) > timeout
+ return res, r.proto, true, consumed, nil
}
func (r *reassembler) checkDoneOrMark() bool {
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index 7eee0710d..a0a04a027 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -18,6 +18,8 @@ import (
"math"
"reflect"
"testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
type updateHolesInput struct {
@@ -94,7 +96,7 @@ var holesTestCases = []struct {
func TestUpdateHoles(t *testing.T) {
for _, c := range holesTestCases {
- r := newReassembler(0)
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
for _, i := range c.in {
r.updateHoles(i.first, i.last, i.more)
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 7c8fb3e0a..6861cfdaf 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -17,32 +17,44 @@ package ip_test
import (
"testing"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
- localIpv4Addr = "\x0a\x00\x00\x01"
- localIpv4PrefixLen = 24
- remoteIpv4Addr = "\x0a\x00\x00\x02"
- ipv4SubnetAddr = "\x0a\x00\x00\x00"
- ipv4SubnetMask = "\xff\xff\xff\x00"
- ipv4Gateway = "\x0a\x00\x00\x03"
- localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- localIpv6PrefixLen = 120
- remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
- ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ localIPv4Addr = "\x0a\x00\x00\x01"
+ remoteIPv4Addr = "\x0a\x00\x00\x02"
+ ipv4SubnetAddr = "\x0a\x00\x00\x00"
+ ipv4SubnetMask = "\xff\xff\xff\x00"
+ ipv4Gateway = "\x0a\x00\x00\x03"
+ localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
+ ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ nicID = 1
)
+var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: localIPv4Addr,
+ PrefixLen: 24,
+}
+
+var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: localIPv6Addr,
+ PrefixLen: 120,
+}
+
// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
// The former is used to pretend that it's a link endpoint so that we can
// inspect packets written by the network endpoints. The latter is used to
@@ -96,9 +108,10 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) {
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
+ return stack.TransportPacketHandled
}
// DeliverTransportControlPacket is called by network endpoints after parsing
@@ -156,13 +169,13 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
var dstAddr tcpip.Address
if t.v4 {
- h := header.IPv4(pkt.Header.View())
+ h := header.IPv4(pkt.NetworkHeader().View())
prot = tcpip.TransportProtocolNumber(h.Protocol())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
} else {
- h := header.IPv6(pkt.Header.View())
+ h := header.IPv6(pkt.NetworkHeader().View())
prot = tcpip.TransportProtocolNumber(h.NextHeader())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
@@ -172,60 +185,345 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
-func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*testObject) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("not implemented")
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
- s.CreateNIC(1, loopback.New())
- s.AddAddress(1, ipv4.ProtocolNumber, local)
+ s.CreateNIC(nicID, loopback.New())
+ s.AddAddress(nicID, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
NIC: 1,
}})
- return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
+ return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
- s.CreateNIC(1, loopback.New())
- s.AddAddress(1, ipv6.ProtocolNumber, local)
+ s.CreateNIC(nicID, loopback.New())
+ s.AddAddress(nicID, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
NIC: 1,
}})
- return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
+ return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
}
-func buildDummyStack() *stack.Stack {
- return stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) {
+ t.Helper()
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
+ e := channel.New(0, 1280, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
+ if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
+ }
+
+ v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
+ if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
+ }
+
+ return s, e
+}
+
+func buildDummyStack(t *testing.T) *stack.Stack {
+ t.Helper()
+
+ s, _ := buildDummyStackWithLinkEndpoint(t)
+ return s
+}
+
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ tester testObject
+
+ mu struct {
+ sync.RWMutex
+ disabled bool
+ }
+}
+
+func (*testInterface) ID() tcpip.NICID {
+ return nicID
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (t *testInterface) Enabled() bool {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return !t.mu.disabled
+}
+
+func (t *testInterface) setEnabled(v bool) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.mu.disabled = !v
+}
+
+func (t *testInterface) LinkEndpoint() stack.LinkEndpoint {
+ return &t.tester
+}
+
+func TestSourceAddressValidation(t *testing.T) {
+ rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv4Addr,
+ })
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv6Addr,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ tests := []struct {
+ name string
+ srcAddress tcpip.Address
+ rxICMP func(*channel.Endpoint, tcpip.Address)
+ valid bool
+ }{
+ {
+ name: "IPv4 valid",
+ srcAddress: "\x01\x02\x03\x04",
+ rxICMP: rxIPv4ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv6 valid",
+ srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10",
+ rxICMP: rxIPv6ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv4 unspecified",
+ srcAddress: header.IPv4Any,
+ rxICMP: rxIPv4ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv6 unspecified",
+ srcAddress: header.IPv4Any,
+ rxICMP: rxIPv6ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv4 multicast",
+ srcAddress: "\xe0\x00\x00\x01",
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv6 multicast",
+ srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ rxICMP: rxIPv6ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv4 broadcast",
+ srcAddress: header.IPv4Broadcast,
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv4 subnet broadcast",
+ srcAddress: func() tcpip.Address {
+ subnet := localIPv4AddrWithPrefix.Subnet()
+ return subnet.Broadcast()
+ }(),
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, e := buildDummyStackWithLinkEndpoint(t)
+ test.rxICMP(e, test.srcAddress)
+
+ var wantValid uint64
+ if test.valid {
+ wantValid = 1
+ }
+
+ if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want {
+ t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid {
+ t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid)
+ }
+ })
+ }
+}
+
+func TestEnableWhenNICDisabled(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ }{
+ {
+ name: "IPv4",
+ protocolFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ },
+ {
+ name: "IPv6",
+ protocolFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var nic testInterface
+ nic.setEnabled(false)
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory},
+ })
+ p := s.NetworkProtocolInstance(test.protoNum)
+
+ // We pass nil for all parameters except the NetworkInterface and Stack
+ // since Enable only depends on these.
+ ep := p.NewEndpoint(&nic, nil, nil, nil)
+
+ // The endpoint should initially be disabled, regardless the NIC's enabled
+ // status.
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+ nic.setEnabled(true)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Attempting to enable the endpoint while the NIC is disabled should
+ // fail.
+ nic.setEnabled(false)
+ if err := ep.Enable(); err != tcpip.ErrNotPermitted {
+ t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted)
+ }
+ // ep should consider the NIC's enabled status when determining its own
+ // enabled status so we "enable" the NIC to read just the endpoint's
+ // enabled status.
+ nic.setEnabled(true)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Enabling the interface after the NIC has been enabled should succeed.
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+ if !ep.Enabled() {
+ t.Fatal("got ep.Enabled() = false, want = true")
+ }
+
+ // ep should consider the NIC's enabled status when determining its own
+ // enabled status.
+ nic.setEnabled(false)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Disabling the endpoint when the NIC is enabled should make the endpoint
+ // disabled.
+ nic.setEnabled(true)
+ ep.Disable()
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+ })
+ }
}
func TestIPv4Send(t *testing.T) {
- o := testObject{t: t, v4: true}
- proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ v4: true,
+ },
}
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
+ defer ep.Close()
// Allocate and initialize the payload view.
payload := buffer.NewView(100)
@@ -233,16 +531,19 @@ func TestIPv4Send(t *testing.T) {
payload[i] = uint8(i)
}
- // Allocate the header buffer.
- hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+ // Setup the packet buffer.
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(ep.MaxHeaderLength()),
+ Data: payload.ToVectorisedView(),
+ })
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv4Addr
- o.dstAddr = remoteIpv4Addr
- o.contents = payload
+ nic.tester.protocol = 123
+ nic.tester.srcAddr = localIPv4Addr
+ nic.tester.dstAddr = remoteIPv4Addr
+ nic.tester.contents = payload
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -250,20 +551,25 @@ func TestIPv4Send(t *testing.T) {
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
- }, &stack.PacketBuffer{
- Header: hdr,
- Data: payload.ToVectorisedView(),
- }); err != nil {
+ }, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
func TestIPv4Receive(t *testing.T) {
- o := testObject{t: t, v4: true}
- proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
}
totalLen := header.IPv4MinimumSize + 30
@@ -274,8 +580,8 @@ func TestIPv4Receive(t *testing.T) {
TotalLength: uint16(totalLen),
TTL: 20,
Protocol: 10,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
// Make payload be non-zero.
@@ -284,20 +590,24 @@ func TestIPv4Receive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[header.IPv4MinimumSize:totalLen]
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIPv4Addr
+ nic.tester.dstAddr = localIPv4Addr
+ nic.tester.contents = view[header.IPv4MinimumSize:totalLen]
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- pkt := stack.PacketBuffer{Data: view.ToVectorisedView()}
- proto.Parse(&pkt)
- ep.HandlePacket(&r, &pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: view.ToVectorisedView(),
+ })
+ if _, _, ok := proto.Parse(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(&r, pkt)
+ if nic.tester.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
}
}
@@ -307,7 +617,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
name string
expectedCount int
fragmentOffset uint16
- code uint8
+ code header.ICMPv4Code
expectedTyp stack.ControlType
expectedExtra uint32
trunc int
@@ -321,20 +631,26 @@ func TestIPv4ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
}
- r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb")
+ r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb")
if err != nil {
t.Fatal(err)
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
- proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
}
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
view := buffer.NewView(dataOffset + 8)
@@ -346,7 +662,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
TTL: 20,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: "\x0a\x00\x00\xbb",
- DstAddr: localIpv4Addr,
+ DstAddr: localIPv4Addr,
})
// Create the ICMP header.
@@ -364,8 +680,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
TTL: 20,
Protocol: 10,
FragmentOffset: c.fragmentOffset,
- SrcAddr: localIpv4Addr,
- DstAddr: remoteIpv4Addr,
+ SrcAddr: localIPv4Addr,
+ DstAddr: remoteIPv4Addr,
})
// Make payload be non-zero.
@@ -375,27 +691,35 @@ func TestIPv4ReceiveControl(t *testing.T) {
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIPv4Addr
+ nic.tester.dstAddr = localIPv4Addr
+ nic.tester.contents = view[dataOffset:]
+ nic.tester.typ = c.expectedTyp
+ nic.tester.extra = c.expectedExtra
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ if want := c.expectedCount; nic.tester.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
}
})
}
}
func TestIPv4FragmentationReceive(t *testing.T) {
- o := testObject{t: t, v4: true}
- proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
}
totalLen := header.IPv4MinimumSize + 24
@@ -409,8 +733,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
Protocol: 10,
FragmentOffset: 0,
Flags: header.IPv4FlagMoreFragments,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -425,8 +749,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
TTL: 20,
Protocol: 10,
FragmentOffset: 24,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -434,39 +758,54 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIPv4Addr
+ nic.tester.dstAddr = localIPv4Addr
+ nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
// Send first segment.
- pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()}
- proto.Parse(&pkt)
- ep.HandlePacket(&r, &pkt)
- if o.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: frag1.ToVectorisedView(),
+ })
+ if _, _, ok := proto.Parse(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(&r, pkt)
+ if nic.tester.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls)
}
// Send second segment.
- pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()}
- proto.Parse(&pkt)
- ep.HandlePacket(&r, &pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: frag2.ToVectorisedView(),
+ })
+ if _, _, ok := proto.Parse(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(&r, pkt)
+ if nic.tester.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
}
}
func TestIPv6Send(t *testing.T) {
- o := testObject{t: t}
- proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
}
// Allocate and initialize the payload view.
@@ -475,16 +814,19 @@ func TestIPv6Send(t *testing.T) {
payload[i] = uint8(i)
}
- // Allocate the header buffer.
- hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+ // Setup the packet buffer.
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(ep.MaxHeaderLength()),
+ Data: payload.ToVectorisedView(),
+ })
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv6Addr
- o.dstAddr = remoteIpv6Addr
- o.contents = payload
+ nic.tester.protocol = 123
+ nic.tester.srcAddr = localIPv6Addr
+ nic.tester.dstAddr = remoteIPv6Addr
+ nic.tester.contents = payload
- r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -492,20 +834,24 @@ func TestIPv6Send(t *testing.T) {
Protocol: 123,
TTL: 123,
TOS: stack.DefaultTOS,
- }, &stack.PacketBuffer{
- Header: hdr,
- Data: payload.ToVectorisedView(),
- }); err != nil {
+ }, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
func TestIPv6Receive(t *testing.T) {
- o := testObject{t: t}
- proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
}
totalLen := header.IPv6MinimumSize + 30
@@ -515,8 +861,8 @@ func TestIPv6Receive(t *testing.T) {
PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
NextHeader: 10,
HopLimit: 20,
- SrcAddr: remoteIpv6Addr,
- DstAddr: localIpv6Addr,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
})
// Make payload be non-zero.
@@ -525,21 +871,25 @@ func TestIPv6Receive(t *testing.T) {
}
// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[header.IPv6MinimumSize:totalLen]
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIPv6Addr
+ nic.tester.dstAddr = localIPv6Addr
+ nic.tester.contents = view[header.IPv6MinimumSize:totalLen]
- r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- pkt := stack.PacketBuffer{Data: view.ToVectorisedView()}
- proto.Parse(&pkt)
- ep.HandlePacket(&r, &pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: view.ToVectorisedView(),
+ })
+ if _, _, ok := proto.Parse(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(&r, pkt)
+ if nic.tester.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
}
}
@@ -553,7 +903,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
expectedCount int
fragmentOffset *uint16
typ header.ICMPv6Type
- code uint8
+ code header.ICMPv6Code
expectedTyp stack.ControlType
expectedExtra uint32
trunc int
@@ -570,7 +920,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
{"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
}
r, err := buildIPv6Route(
- localIpv6Addr,
+ localIPv6Addr,
"\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
)
if err != nil {
@@ -578,15 +928,20 @@ func TestIPv6ReceiveControl(t *testing.T) {
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
- proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ s := buildDummyStack(t)
+ proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
}
-
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
if c.fragmentOffset != nil {
dataOffset += header.IPv6FragmentHeaderSize
@@ -600,7 +955,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: 20,
SrcAddr: outerSrcAddr,
- DstAddr: localIpv6Addr,
+ DstAddr: localIPv6Addr,
})
// Create the ICMP header.
@@ -616,8 +971,8 @@ func TestIPv6ReceiveControl(t *testing.T) {
PayloadLength: 100,
NextHeader: 10,
HopLimit: 20,
- SrcAddr: localIpv6Addr,
- DstAddr: remoteIpv6Addr,
+ SrcAddr: localIPv6Addr,
+ DstAddr: remoteIPv6Addr,
})
// Build the fragmentation header if needed.
@@ -639,19 +994,19 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Give packet to IPv6 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIPv6Addr
+ nic.tester.dstAddr = localIPv6Addr
+ nic.tester.contents = view[dataOffset:]
+ nic.tester.typ = c.expectedTyp
+ nic.tester.extra = c.expectedExtra
// Set ICMPv6 checksum.
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ if want := c.expectedCount; nic.tester.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
}
})
}
@@ -663,11 +1018,9 @@ func TestIPv6ReceiveControl(t *testing.T) {
// becomes Data.
func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer {
v := view[:len(view)-trunc]
- if len(v) < netHdrLen {
- return &stack.PacketBuffer{Data: v.ToVectorisedView()}
- }
- return &stack.PacketBuffer{
- NetworkHeader: v[:netHdrLen],
- Data: v[netHdrLen:].ToVectorisedView(),
- }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: v.ToVectorisedView(),
+ })
+ _, _ = pkt.NetworkHeader().Consume(netHdrLen)
+ return pkt
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 78420d6e6..ee2c23e91 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -10,9 +10,11 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
@@ -26,14 +28,17 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 1b67aa066..eab9a530c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,6 +15,9 @@
package ipv4
import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -37,8 +40,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// false.
//
// Drop packet if it doesn't have the basic IPv4 header or if the
- // original source address doesn't match the endpoint's address.
- if hdr.SourceAddress() != e.id.LocalAddress {
+ // original source address doesn't match an address we own.
+ src := hdr.SourceAddress()
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -53,7 +57,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Skip the ip header, then deliver control message.
pkt.Data.TrimFront(hlen)
p := hdr.TransportProtocol()
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
@@ -75,45 +79,92 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
received.Echo.Increment()
// Only send a reply if the checksum is valid.
- wantChecksum := h.Checksum()
- // Reset the checksum field to 0 to can calculate the proper
- // checksum. We'll have to reset this before we hand the packet
- // off.
+ headerChecksum := h.Checksum()
h.SetChecksum(0)
- gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
- if gotChecksum != wantChecksum {
- // It's possible that a raw socket expects to receive this.
- h.SetChecksum(wantChecksum)
+ calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
+ h.SetChecksum(headerChecksum)
+ if calculatedChecksum != headerChecksum {
+ // It's possible that a raw socket still expects to receive this.
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
received.Invalid.Increment()
return
}
- // It's possible that a raw socket expects to receive this.
- h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{
- Data: pkt.Data.Clone(nil),
- NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...),
+ // DeliverTransportPacket will take ownership of pkt so don't use it beyond
+ // this point. Make a deep copy of the data before pkt gets sent as we will
+ // be modifying fields.
+ //
+ // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
+ // waiting endpoints. Consider moving responsibility for doing the copy to
+ // DeliverTransportPacket so that is is only done when needed.
+ replyData := pkt.Data.ToOwnedView()
+ replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...))
+
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+
+ remoteLinkAddr := r.RemoteLinkAddress
+
+ // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
+ // source address MUST be one of its own IP addresses (but not a broadcast
+ // or multicast address).
+ localAddr := r.LocalAddress
+ if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) {
+ localAddr = ""
+ }
+
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ // If we cannot find a route to the destination, silently drop the packet.
+ return
+ }
+ defer r.Release()
+
+ // Use the remote link address from the incoming packet.
+ r.ResolveWith(remoteLinkAddr)
+
+ // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the
+ // header information, we may have to change this code to handle the
+ // ICMP header no longer being in the data buffer.
+
+ // Because IP and ICMP are so closely intertwined, we need to handcraft our
+ // IP header to be able to follow RFC 792. The wording on page 13 is as
+ // follows:
+ // IP Fields:
+ // Addresses
+ // The address of the source in an echo message will be the
+ // destination of the echo reply message. To form an echo reply
+ // message, the source and destination addresses are simply reversed,
+ // the type code changed to 0, and the checksum recomputed.
+ //
+ // This was interpreted by early implementors to mean that all options must
+ // be copied from the echo request IP header to the echo reply IP header
+ // and this behaviour is still relied upon by some applications.
+ //
+ // Create a copy of the IP header we received, options and all, and change
+ // The fields we need to alter.
+ //
+ // We need to produce the entire packet in the data segment in order to
+ // use WriteHeaderIncludedPacket().
+ replyIPHdr.SetSourceAddress(r.LocalAddress)
+ replyIPHdr.SetDestinationAddress(r.RemoteAddress)
+ replyIPHdr.SetTTL(r.DefaultTTL())
+
+ replyICMPHdr := header.ICMPv4(replyData)
+ replyICMPHdr.SetType(header.ICMPv4EchoReply)
+ replyICMPHdr.SetChecksum(0)
+ replyICMPHdr.SetChecksum(^header.Checksum(replyData, 0))
+
+ replyVV := buffer.View(replyIPHdr).ToVectorisedView()
+ replyVV.AppendView(replyData)
+ replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: replyVV,
})
+ replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
- vv := pkt.Data.Clone(nil)
- vv.TrimFront(header.ICMPv4MinimumSize)
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- copy(pkt, h)
- pkt.SetType(header.ICMPv4EchoReply)
- pkt.SetChecksum(0)
- pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
+ // The checksum will be calculated so we don't need to do it here.
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
- Protocol: header.ICMPv4ProtocolNumber,
- TTL: r.DefaultTTL(),
- TOS: stack.DefaultTOS,
- }, &stack.PacketBuffer{
- Header: hdr,
- Data: vv,
- TransportHeader: buffer.View(pkt),
- }); err != nil {
+ if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -129,6 +180,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
+ case header.ICMPv4HostUnreachable:
+ e.handleControl(stack.ControlNoRoute, 0, pkt)
+
case header.ICMPv4PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
@@ -165,3 +219,177 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
received.Invalid.Increment()
}
}
+
+// ======= ICMP Error packet generation =========
+
+// icmpReason is a marker interface for IPv4 specific ICMP errors.
+type icmpReason interface {
+ isICMPReason()
+}
+
+// icmpReasonPortUnreachable is an error where the transport protocol has no
+// listener and no alternative means to inform the sender.
+type icmpReasonPortUnreachable struct{}
+
+func (*icmpReasonPortUnreachable) isICMPReason() {}
+
+// icmpReasonProtoUnreachable is an error where the transport protocol is
+// not supported.
+type icmpReasonProtoUnreachable struct{}
+
+func (*icmpReasonProtoUnreachable) isICMPReason() {}
+
+// returnError takes an error descriptor and generates the appropriate ICMP
+// error packet for IPv4 and sends it back to the remote device that sent
+// the problematic packet. It incorporates as much of that packet as
+// possible as well as any error metadata as is available. returnError
+// expects pkt to hold a valid IPv4 packet as per the wire format.
+func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ sent := r.Stats().ICMP.V4PacketsSent
+ if !r.Stack().AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
+ // We check we are responding only when we are allowed to.
+ // See RFC 1812 section 4.3.2.7 (shown below).
+ //
+ // =========
+ // 4.3.2.7 When Not to Send ICMP Errors
+ //
+ // An ICMP error message MUST NOT be sent as the result of receiving:
+ //
+ // o An ICMP error message, or
+ //
+ // o A packet which fails the IP header validation tests described in
+ // Section [5.2.2] (except where that section specifically permits
+ // the sending of an ICMP error message), or
+ //
+ // o A packet destined to an IP broadcast or IP multicast address, or
+ //
+ // o A packet sent as a Link Layer broadcast or multicast, or
+ //
+ // o Any fragment of a datagram other then the first fragment (i.e., a
+ // packet for which the fragment offset in the IP header is nonzero).
+ //
+ // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in
+ // response to a non-initial fragment, but it currently can not happen.
+
+ if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any {
+ return nil
+ }
+
+ networkHeader := pkt.NetworkHeader().View()
+ transportHeader := pkt.TransportHeader().View()
+
+ // Don't respond to icmp error packets.
+ if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) {
+ // TODO(gvisor.dev/issue/3810):
+ // Unfortunately the current stack pretty much always has ICMPv4 headers
+ // in the Data section of the packet but there is no guarantee that is the
+ // case. If this is the case grab the header to make it like all other
+ // packet types. When this is cleaned up the Consume should be removed.
+ if transportHeader.IsEmpty() {
+ var ok bool
+ transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize)
+ if !ok {
+ return nil
+ }
+ } else if transportHeader.Size() < header.ICMPv4MinimumSize {
+ return nil
+ }
+ // We need to decide to explicitly name the packets we can respond to or
+ // the ones we can not respond to. The decision is somewhat arbitrary and
+ // if problems arise this could be reversed. It was judged less of a breach
+ // of protocol to not respond to unknown non-error packets than to respond
+ // to unknown error packets so we take the first approach.
+ switch header.ICMPv4(transportHeader).Type() {
+ case
+ header.ICMPv4EchoReply,
+ header.ICMPv4Echo,
+ header.ICMPv4Timestamp,
+ header.ICMPv4TimestampReply,
+ header.ICMPv4InfoRequest,
+ header.ICMPv4InfoReply:
+ default:
+ // Assume any type we don't know about may be an error type.
+ return nil
+ }
+ }
+
+ // Now work out how much of the triggering packet we should return.
+ // As per RFC 1812 Section 4.3.2.3
+ //
+ // ICMP datagram SHOULD contain as much of the original
+ // datagram as possible without the length of the ICMP
+ // datagram exceeding 576 bytes.
+ //
+ // NOTE: The above RFC referenced is different from the original
+ // recommendation in RFC 1122 and RFC 792 where it mentioned that at
+ // least 8 bytes of the payload must be included. Today linux and other
+ // systems implement the RFC 1812 definition and not the original
+ // requirement. We treat 8 bytes as the minimum but will try send more.
+ mtu := int(r.MTU())
+ if mtu > header.IPv4MinimumProcessableDatagramSize {
+ mtu = header.IPv4MinimumProcessableDatagramSize
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ available := int(mtu) - headerLen
+
+ if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize {
+ return nil
+ }
+
+ payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+
+ // The buffers used by pkt may be used elsewhere in the system.
+ // For example, an AF_RAW or AF_PACKET socket may use what the transport
+ // protocol considers an unreachable destination. Thus we deep copy pkt to
+ // prevent multiple ownership and SR errors. The new copy is a vectorized
+ // view with the entire incoming IP packet reassembled and truncated as
+ // required. This is now the payload of the new ICMP packet and no longer
+ // considered a packet in its own right.
+ newHeader := append(buffer.View(nil), networkHeader...)
+ newHeader = append(newHeader, transportHeader...)
+ payload := newHeader.ToVectorisedView()
+ payload.AppendView(pkt.Data.ToView())
+ payload.CapLength(payloadLen)
+
+ icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: headerLen,
+ Data: payload,
+ })
+
+ icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
+
+ icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ switch reason.(type) {
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ case *icmpReasonProtoUnreachable:
+ icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
+ counter := sent.DstUnreachable
+
+ if err := r.WritePacket(
+ nil, /* gso */
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv4ProtocolNumber,
+ TTL: r.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ },
+ icmpPkt,
+ ); err != nil {
+ sent.Dropped.Increment()
+ return err
+ }
+ counter.Increment()
+ return nil
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 7e9f16c90..a2be64fb8 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -12,21 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ipv4 contains the implementation of the ipv4 network protocol. To use
-// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv4.NewProtocol() as one of the network
-// protocols when calling stack.New(). Then endpoints can be created by passing
-// ipv4.ProtocolNumber as the network protocol number when calling
-// Stack.NewEndpoint().
+// Package ipv4 contains the implementation of the ipv4 network protocol.
package ipv4
import (
"fmt"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -45,70 +42,139 @@ const (
// buckets is the number of identifier buckets.
buckets = 2048
+
+ // The size of a fragment block, in bytes, as per RFC 791 section 3.1,
+ // page 14.
+ fragmentblockSize = 8
)
+var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
+
+var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- nicID tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
- linkEP stack.LinkEndpoint
- dispatcher stack.TransportDispatcher
- fragmentation *fragmentation.Fragmentation
- protocol *protocol
- stack *stack.Stack
+ nic stack.NetworkInterface
+ linkEP stack.LinkEndpoint
+ dispatcher stack.TransportDispatcher
+ protocol *protocol
+
+ // enabled is set to 1 when the enpoint is enabled and 0 when it is
+ // disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ mu struct {
+ sync.RWMutex
+
+ addressableEndpointState stack.AddressableEndpointState
+ }
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
- nicID: nicID,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
- linkEP: linkEP,
- dispatcher: dispatcher,
- fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
- protocol: p,
- stack: st,
+ nic: nic,
+ linkEP: nic.LinkEndpoint(),
+ dispatcher: dispatcher,
+ protocol: p,
}
+ e.mu.addressableEndpointState.Init(e)
+ return e
+}
- return e, nil
+// Enable implements stack.NetworkEndpoint.
+func (e *endpoint) Enable() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // If the NIC is not enabled, the endpoint can't do anything meaningful so
+ // don't enable the endpoint.
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ // If the endpoint is already enabled, there is nothing for it to do.
+ if !e.setEnabled(true) {
+ return nil
+ }
+
+ // Create an endpoint to receive broadcast packets on this interface.
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ if err != nil {
+ return err
+ }
+ // We have no need for the address endpoint.
+ ep.DecRef()
+
+ // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
+ // multicast group. Note, the IANA calls the all-hosts multicast group the
+ // all-systems multicast group.
+ _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems)
+ return err
}
-// DefaultTTL is the default time-to-live value for this endpoint.
-func (e *endpoint) DefaultTTL() uint8 {
- return e.protocol.DefaultTTL()
+// Enabled implements stack.NetworkEndpoint.
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
}
-// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
-// the network layer max header length.
-func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
}
-// Capabilities implements stack.NetworkEndpoint.Capabilities.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
+// setEnabled sets the enabled status for the endpoint.
+//
+// Returns true if the enabled status was updated.
+func (e *endpoint) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&e.enabled, 1) == 0
+ }
+ return atomic.SwapUint32(&e.enabled, 0) == 1
+}
+
+// Disable implements stack.NetworkEndpoint.
+func (e *endpoint) Disable() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.disableLocked()
}
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
+func (e *endpoint) disableLocked() {
+ if !e.setEnabled(false) {
+ return
+ }
+
+ // The endpoint may have already left the multicast group.
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
+ }
+
+ // The address may have already been removed.
+ if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
+ }
}
-// ID returns the ipv4 endpoint ID.
-func (e *endpoint) ID() *stack.NetworkEndpointID {
- return &e.id
+// DefaultTTL is the default time-to-live value for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return e.protocol.DefaultTTL()
}
-// PrefixLen returns the ipv4 endpoint subnet prefix length in bits.
-func (e *endpoint) PrefixLen() int {
- return e.prefixLen
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ return calculateMTU(e.linkEP.MTU())
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
+ return e.linkEP.MaxHeaderLength() + header.IPv4MaximumHeaderSize
}
// GSOMaxSize returns the maximum GSO packet size.
@@ -125,14 +191,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
-// write. It assumes that the IP header is entirely in pkt.Header but does not
-// assume that only the IP header is in pkt.Header. It assumes that the input
-// packet's stated length matches the length of the header+payload. mtu
-// includes the IP header and options. This does not support the DontFragment
-// IP flag.
+// write. It assumes that the IP header is already present in pkt.NetworkHeader.
+// pkt.TransportHeader may be set. mtu includes the IP header and options. This
+// does not support the DontFragment IP flag.
func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error {
// This packet is too big, it needs to be fragmented.
- ip := header.IPv4(pkt.Header.View())
+ ip := header.IPv4(pkt.NetworkHeader().View())
flags := ip.Flags()
// Update mtu to take into account the header, which will exist in all
@@ -146,91 +210,89 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
outerMTU := innerMTU + int(ip.HeaderLength())
offset := ip.FragmentOffset()
- originalAvailableLength := pkt.Header.AvailableLength()
+
+ // Keep the length reserved for link-layer, we need to create fragments with
+ // the same reserved length.
+ reservedForLink := pkt.AvailableHeaderBytes()
+
+ // Destroy the packet, pull all payloads out for fragmentation.
+ transHeader, data := pkt.TransportHeader().View(), pkt.Data
+
+ // Where possible, the first fragment that is sent has the same
+ // number of bytes reserved for header as the input packet. The link-layer
+ // endpoint may depend on this for looking at, eg, L4 headers.
+ transFitsFirst := len(transHeader) <= innerMTU
+
for i := 0; i < n; i++ {
- // Where possible, the first fragment that is sent has the same
- // pkt.Header.UsedLength() as the input packet. The link-layer
- // endpoint may depend on this for looking at, eg, L4 headers.
- h := ip
- if i > 0 {
- pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
- h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength())))
- copy(h, ip[:ip.HeaderLength()])
+ reserve := reservedForLink + int(ip.HeaderLength())
+ if i == 0 && transFitsFirst {
+ // Reserve for transport header if it's going to be put in the first
+ // fragment.
+ reserve += len(transHeader)
}
+ fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: reserve,
+ })
+ fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+
+ // Copy data for the fragment.
+ avail := innerMTU
+
+ if n := len(transHeader); n > 0 {
+ if n > avail {
+ n = avail
+ }
+ if i == 0 && transFitsFirst {
+ copy(fragPkt.TransportHeader().Push(n), transHeader)
+ } else {
+ fragPkt.Data.AppendView(transHeader[:n:n])
+ }
+ transHeader = transHeader[n:]
+ avail -= n
+ }
+
+ if avail > 0 {
+ n := data.Size()
+ if n > avail {
+ n = avail
+ }
+ data.ReadToVV(&fragPkt.Data, n)
+ avail -= n
+ }
+
+ copied := uint16(innerMTU - avail)
+
+ // Set lengths in header and calculate checksum.
+ h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip)))
+ copy(h, ip)
if i != n-1 {
h.SetTotalLength(uint16(outerMTU))
h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
} else {
- h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size()))
+ h.SetTotalLength(uint16(h.HeaderLength()) + copied)
h.SetFlagsFragmentOffset(flags, offset)
}
h.SetChecksum(0)
h.SetChecksum(^h.CalculateChecksum())
- offset += uint16(innerMTU)
- if i > 0 {
- newPayload := pkt.Data.Clone(nil)
- newPayload.CapLength(innerMTU)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
- Header: pkt.Header,
- Data: newPayload,
- NetworkHeader: buffer.View(h),
- }); err != nil {
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- pkt.Data.TrimFront(newPayload.Size())
- continue
- }
- // Special handling for the first fragment because it comes
- // from the header.
- if outerMTU >= pkt.Header.UsedLength() {
- // This fragment can fit all of pkt.Header and possibly
- // some of pkt.Data, too.
- newPayload := pkt.Data.Clone(nil)
- newPayloadLength := outerMTU - pkt.Header.UsedLength()
- newPayload.CapLength(newPayloadLength)
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
- Header: pkt.Header,
- Data: newPayload,
- NetworkHeader: buffer.View(h),
- }); err != nil {
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- pkt.Data.TrimFront(newPayloadLength)
- } else {
- // The fragment is too small to fit all of pkt.Header.
- startOfHdr := pkt.Header
- startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU)
- emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
- Header: startOfHdr,
- Data: emptyVV,
- NetworkHeader: buffer.View(h),
- }); err != nil {
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- // Add the unused bytes of pkt.Header into the pkt.Data
- // that remains to be sent.
- restOfHdr := pkt.Header.View()[outerMTU:]
- tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
- tmp.Append(pkt.Data)
- pkt.Data = tmp
+ offset += copied
+
+ // Send out the fragment.
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(n - i))
+ return err
}
+ r.Stats().IP.PacketsSent.Increment()
}
return nil
}
-func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 {
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- length := uint16(hdr.UsedLength() + payloadSize)
- id := uint32(0)
- if length > header.IPv4MaximumHeaderSize+8 {
- // Packets of 68 bytes or less are required by RFC 791 to not be
- // fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
- }
+func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+ ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
+ length := uint16(pkt.Size())
+ // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
+ // datagrams. Since the DF bit is never being set here, all datagrams
+ // are non-atomic and need an ID.
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: length,
@@ -242,31 +304,33 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
- return ip
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
}
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
- pkt.NetworkHeader = buffer.View(ip)
+ e.addIPHeader(r, pkt, params)
- nicName := e.stack.FindNICNameFromID(e.NICID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
- // If the packet is manipulated as per NAT Ouput rules, handle packet
- // based on destination address and do not send the packet to link layer.
- // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than
- // only NATted packets, but removing this check short circuits broadcasts
- // before they are sent out to other hosts.
+ // If the packet is manipulated as per NAT Output rules, handle packet
+ // based on destination address and do not send the packet to link
+ // layer.
+ //
+ // TODO(gvisor.dev/issue/170): We should do this for every
+ // packet, rather than only NATted packets, but removing this check
+ // short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
- netHeader := header.IPv4(pkt.NetworkHeader)
- ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
@@ -282,10 +346,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
}
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
r.Stats().IP.PacketsSent.Increment()
@@ -302,25 +367,28 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
for pkt := pkts.Front(); pkt != nil; {
- ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
- pkt.NetworkHeader = buffer.View(ip)
+ e.addIPHeader(r, pkt, params)
pkt = pkt.Next()
}
- nicName := e.stack.FindNICNameFromID(e.NICID())
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
return n, err
}
+ r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
- // Slow Path as we are dropping some packets in the batch degrade to
+ // Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
@@ -328,8 +396,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
continue
}
if _, ok := natPkts[pkt]; ok {
- netHeader := header.IPv4(pkt.NetworkHeader)
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -340,12 +408,16 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
+ // Dropped packets aren't errors, so include them in
+ // the return value.
+ return n + len(dropped), err
}
n++
}
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, nil
+ // Dropped packets aren't errors, so include them in the return value.
+ return n + len(dropped), nil
}
// WriteHeaderIncludedPacket writes a packet already containing a network
@@ -376,13 +448,12 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// Set the packet ID when zero.
if ip.ID() == 0 {
- id := uint32(0)
- if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 {
- // Packets of 68 bytes or less are required by RFC 791 to not be
- // fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
+ // RFC 6864 section 4.3 mandates uniqueness of ID values for
+ // non-atomic datagrams, so assign an ID to all such datagrams
+ // according to the definition given in RFC 6864 section 4.
+ if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
+ ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
- ip.SetID(uint16(id))
}
// Always set the checksum.
@@ -396,33 +467,47 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return nil
}
+ if err := e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
r.Stats().IP.PacketsSent.Increment()
-
- ip = ip[:ip.HeaderLength()]
- pkt.Header = buffer.NewPrependableFromView(buffer.View(ip))
- pkt.Data.TrimFront(int(ip.HeaderLength()))
- return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ return nil
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
- h := header.IPv4(pkt.NetworkHeader)
- if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) {
+ if !e.isEnabled() {
+ return
+ }
+
+ h := header.IPv4(pkt.NetworkHeader().View())
+ if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
+ // As per RFC 1122 section 3.2.1.3:
+ // When a host sends any datagram, the IP source address MUST
+ // be one of its own IP addresses (but not a broadcast or
+ // multicast address).
+ if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) {
+ r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesInputDropped.Increment()
return
}
if h.More() || h.FragmentOffset() != 0 {
- if pkt.Data.Size()+len(pkt.TransportHeader) == 0 {
+ if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -430,18 +515,37 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
// The packet is a fragment, let's try to reassemble it.
- last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1
- // Drop the packet if the fragmentOffset is incorrect. i.e the
- // combination of fragmentOffset and pkt.Data.size() causes a
- // wrap around resulting in last being less than the offset.
- if last < h.FragmentOffset() {
+ start := h.FragmentOffset()
+ // Drop the fragment if the size of the reassembled payload would exceed the
+ // maximum payload size.
+ //
+ // Note that this addition doesn't overflow even on 32bit architecture
+ // because pkt.Data.Size() should not exceed 65535 (the max IP datagram
+ // size). Otherwise the packet would've been rejected as invalid before
+ // reaching here.
+ if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
var ready bool
var err error
- pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, h.More(), pkt.Data)
+ proto := h.Protocol()
+ pkt.Data, _, ready, err = e.protocol.fragmentation.Process(
+ // As per RFC 791 section 2.3, the identification value is unique
+ // for a source-destination pair and protocol.
+ fragmentation.FragmentID{
+ Source: h.SourceAddress(),
+ Destination: h.DestinationAddress(),
+ ID: uint32(h.ID()),
+ Protocol: proto,
+ },
+ start,
+ start+uint16(pkt.Data.Size())-1,
+ h.More(),
+ proto,
+ pkt.Data,
+ )
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
@@ -451,27 +555,166 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
}
+
+ r.Stats().IP.PacketsDelivered.Increment()
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
- pkt.NetworkHeader.CapLength(int(h.HeaderLength()))
+ // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport
+ // headers, the setting of the transport number here should be
+ // unnecessary and removed.
+ pkt.TransportProtocolNumber = p
e.handleICMP(r, pkt)
return
}
- r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, pkt)
+
+ switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ case stack.TransportPacketHandled:
+ case stack.TransportPacketDestinationPortUnreachable:
+ // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
+ // Unreachable messages with code:
+ // 3 (Port Unreachable), when the designated transport protocol
+ // (e.g., UDP) is unable to demultiplex the datagram but has no
+ // protocol mechanism to inform the sender.
+ _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC: 1122 Section 3.2.2.1
+ // A host SHOULD generate Destination Unreachable messages with code:
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported
+ _ = returnError(r, &icmpReasonProtoUnreachable{}, pkt)
+ default:
+ panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
+ }
}
// Close cleans up resources associated with the endpoint.
-func (e *endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.disableLocked()
+ e.mu.addressableEndpointState.Cleanup()
+}
+
+// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+}
+
+// RemovePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
+}
+
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
+// AcquireAssignedAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ loopback := e.nic.IsLoopback()
+ addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
+ subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ // IPv4 has a notion of a subnet broadcast address and considers the
+ // loopback interface bound to an address's whole subnet (on linux).
+ return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
+ })
+ if addressEndpoint != nil {
+ return addressEndpoint
+ }
+
+ if !allowTemp {
+ return nil
+ }
+
+ addr := localAddr.WithPrefix()
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB)
+ if err != nil {
+ // AddAddress only returns an error if the address is already assigned,
+ // but we just checked above if the address exists so we expect no error.
+ panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err))
+ }
+ return addressEndpoint
+}
+
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
+}
+
+// PrimaryAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PrimaryAddresses()
+}
+
+// PermanentAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PermanentAddresses()
+}
+
+// JoinGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ if !header.IsV4MulticastAddress(addr) {
+ return false, tcpip.ErrBadAddress
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.JoinGroup(addr)
+}
+
+// LeaveGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.LeaveGroup(addr)
+}
+
+// IsInGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.IsInGroup(addr)
+}
+
+var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
+var _ stack.NetworkProtocol = (*protocol)(nil)
type protocol struct {
- ids []uint32
- hashIV uint32
+ stack *stack.Stack
// defaultTTL is the current default TTL for the protocol. Only the
- // uint8 portion of it is meaningful and it must be accessed
- // atomically.
+ // uint8 portion of it is meaningful.
+ //
+ // Must be accessed using atomic operations.
defaultTTL uint32
+
+ // forwarding is set to 1 when the protocol has forwarding enabled and 0
+ // when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
+ ids []uint32
+ hashIV uint32
+
+ fragmentation *fragmentation.Fragmentation
}
// Number returns the ipv4 protocol number.
@@ -496,10 +739,10 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// SetOption implements NetworkProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case tcpip.DefaultTTLOption:
- p.SetDefaultTTL(uint8(v))
+ case *tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(*v))
return nil
default:
return tcpip.ErrUnknownProtocolOption
@@ -507,7 +750,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
// Option implements NetworkProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(p.DefaultTTL())
@@ -533,33 +776,28 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// Parse implements stack.TransportProtocol.Parse.
+// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
- hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
+ if ok := parse.IPv4(pkt); !ok {
return 0, false, false
}
- ipHdr := header.IPv4(hdr)
- // If there are options, pull those into hdr as well.
- if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() {
- hdr, ok = pkt.Data.PullUp(headerLen)
- if !ok {
- panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen))
- }
- ipHdr = header.IPv4(hdr)
- }
+ ipHdr := header.IPv4(pkt.NetworkHeader().View())
+ return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
+}
- // If this is a fragment, don't bother parsing the transport header.
- parseTransportHeader := true
- if ipHdr.More() || ipHdr.FragmentOffset() != 0 {
- parseTransportHeader = false
- }
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) Forwarding() bool {
+ return uint8(atomic.LoadUint32(&p.forwarding)) == 1
+}
- pkt.NetworkHeader = hdr
- pkt.Data.TrimFront(len(hdr))
- pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
- return ipHdr.TransportProtocol(), parseTransportHeader, true
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) SetForwarding(v bool) {
+ if v {
+ atomic.StoreUint32(&p.forwarding, 1)
+ } else {
+ atomic.StoreUint32(&p.forwarding, 0)
+ }
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
@@ -583,7 +821,7 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV ui
}
// NewProtocol returns an IPv4 network protocol.
-func NewProtocol() stack.NetworkProtocol {
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
@@ -593,5 +831,11 @@ func NewProtocol() stack.NetworkProtocol {
}
hashIV := r[buckets]
- return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL}
+ return &protocol{
+ stack: s,
+ ids: ids,
+ hashIV: hashIV,
+ defaultTTL: DefaultTTL,
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()),
+ }
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 11e579c4b..712fbb861 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -17,17 +17,21 @@ package ipv4_test
import (
"bytes"
"encoding/hex"
- "math/rand"
+ "math"
+ "net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -35,8 +39,8 @@ import (
func TestExcludeBroadcast(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
const defaultMTU = 65536
@@ -91,25 +95,274 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much
-// data should already be in the header before WritePacket. extraLength
-// indicates how much extra space should be in the header. The payload is made
-// from many Views of the sizes listed in viewSizes.
-func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) {
- hdr := buffer.NewPrependable(hdrLength + extraLength)
- hdr.Prepend(hdrLength)
- rand.Read(hdr.View())
-
- var views []buffer.View
- totalLength := 0
- for _, s := range viewSizes {
- newView := buffer.NewView(s)
- rand.Read(newView)
- views = append(views, newView)
- totalLength += s
+// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
+// checks the response.
+func TestIPv4Sanity(t *testing.T) {
+ const (
+ defaultMTU = header.IPv6MinimumMTU
+ ttl = 255
+ nicID = 1
+ randomSequence = 123
+ randomIdent = 42
+ )
+ var (
+ ipv4Addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
+ PrefixLen: 24,
+ }
+ remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4())
+ )
+
+ tests := []struct {
+ name string
+ headerLength uint8 // value of 0 means "use correct size"
+ maxTotalLength uint16
+ transportProtocol uint8
+ TTL uint8
+ shouldFail bool
+ expectICMP bool
+ ICMPType header.ICMPv4Type
+ ICMPCode header.ICMPv4Code
+ options []byte
+ }{
+ {
+ name: "valid",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ },
+ // The TTL tests check that we are not rejecting an incoming packet
+ // with a zero or one TTL, which has been a point of confusion in the
+ // past as RFC 791 says: "If this field contains the value zero, then the
+ // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies
+ // for the case of the destination host, stating as follows.
+ //
+ // A host MUST NOT send a datagram with a Time-to-Live (TTL)
+ // value of zero.
+ //
+ // A host MUST NOT discard a datagram just because it was
+ // received with TTL less than 2.
+ {
+ name: "zero TTL",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 0,
+ shouldFail: false,
+ },
+ {
+ name: "one TTL",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ shouldFail: false,
+ },
+ {
+ name: "End options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{0, 0, 0, 0},
+ },
+ {
+ name: "NOP options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{1, 1, 1, 1},
+ },
+ {
+ name: "NOP and End options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{1, 1, 0, 0},
+ },
+ {
+ name: "bad header length",
+ headerLength: header.IPv4MinimumSize - 1,
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (0)",
+ maxTotalLength: 0,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (ip - 1)",
+ maxTotalLength: uint16(header.IPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (ip + icmp - 1)",
+ maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad protocol",
+ maxTotalLength: defaultMTU,
+ transportProtocol: 99,
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: true,
+ ICMPType: header.ICMPv4DstUnreachable,
+ ICMPCode: header.ICMPv4ProtoUnreachable,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ }
+
+ // Default routes for IPv4 so ICMP can find a route to the remote
+ // node when attempting to send the ICMP Echo Reply.
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // Round up the header size to the next multiple of 4 as RFC 791, page 11
+ // says: "Internet Header Length is the length of the internet header
+ // in 32 bit words..." and on page 23: "The internet header padding is
+ // used to ensure that the internet header ends on a 32 bit boundary."
+ ipHeaderLength := ((header.IPv4MinimumSize + len(test.options)) + header.IPv4IHLStride - 1) & ^(header.IPv4IHLStride - 1)
+
+ if ipHeaderLength > header.IPv4MaximumHeaderSize {
+ t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ }
+ totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(totalLen))
+ icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+
+ // Specify ident/seq to make sure we get the same in the response.
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv4Echo)
+ icmp.SetCode(header.ICMPv4UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(^header.Checksum(icmp, 0))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+ if test.maxTotalLength < totalLen {
+ totalLen = test.maxTotalLength
+ }
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(ipHeaderLength),
+ TotalLength: totalLen,
+ Protocol: test.transportProtocol,
+ TTL: test.TTL,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: ipv4Addr.Address,
+ })
+ if n := copy(ip.Options(), test.options); n != len(test.options) {
+ t.Fatalf("options larger than available space: copied %d/%d bytes", n, len(test.options))
+ }
+ // Override the correct value if the test case specified one.
+ if test.headerLength != 0 {
+ ip.SetHeaderLength(test.headerLength)
+ }
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+ reply, ok := e.Read()
+ if !ok {
+ if test.shouldFail {
+ if test.expectICMP {
+ t.Fatal("expected ICMP error response missing")
+ }
+ return // Expected silent failure.
+ }
+ t.Fatal("expected ICMP echo reply missing")
+ }
+
+ // Check the route that brought the packet to us.
+ if reply.Route.LocalAddress != ipv4Addr.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address)
+ }
+ if reply.Route.RemoteAddress != remoteIPv4Addr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr)
+ }
+
+ // Make sure it's all in one buffer.
+ vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views())
+ replyIPHeader := header.IPv4(vv.ToView())
+
+ // At this stage we only know it's an IP header so verify that much.
+ checker.IPv4(t, replyIPHeader,
+ checker.SrcAddr(ipv4Addr.Address),
+ checker.DstAddr(remoteIPv4Addr),
+ )
+
+ // All expected responses are ICMP packets.
+ if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want {
+ t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want)
+ }
+ replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
+
+ // Sanity check the response.
+ switch replyICMPHeader.Type() {
+ case header.ICMPv4DstUnreachable:
+ checker.IPv4(t, replyIPHeader,
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Code(test.ICMPCode),
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
+ if !test.shouldFail || !test.expectICMP {
+ t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d",
+ header.ICMPv4DstUnreachable, replyICMPHeader.Code())
+ }
+ return
+ case header.ICMPv4EchoReply:
+ checker.IPv4(t, replyIPHeader,
+ checker.IPv4HeaderLength(ipHeaderLength),
+ checker.IPv4Options(test.options),
+ checker.IPFullLength(uint16(requestPkt.Size())),
+ checker.ICMPv4(
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Seq(randomSequence),
+ checker.ICMPv4Ident(randomIdent),
+ checker.ICMPv4Checksum(),
+ ),
+ )
+ if test.shouldFail {
+ t.Fatalf("unexpected Echo Reply packet\n")
+ }
+ default:
+ t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d",
+ replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable)
+ }
+ })
}
- payload := buffer.NewVectorisedView(totalLength, views)
- return hdr, payload
}
// comparePayloads compared the contents of all the packets against the contents
@@ -117,9 +370,9 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.
func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
t.Helper()
// Make a complete array of the sourcePacketInfo packet.
- source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
- source = append(source, sourcePacketInfo.Header.View()...)
- source = append(source, sourcePacketInfo.Data.ToView()...)
+ source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize])
+ vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views())
+ source = append(source, vv.ToView()...)
// Make a copy of the IP header, which will be modified in some fields to make
// an expected header.
@@ -132,8 +385,7 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
var reassembledPayload []byte
for i, packet := range packets {
// Confirm that the packet is valid.
- allBytes := packet.Header.View().ToVectorisedView()
- allBytes.Append(packet.Data)
+ allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views())
ip := header.IPv4(allBytes.ToView())
if !ip.IsValid(len(ip)) {
t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
@@ -144,12 +396,22 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
if got, want := len(ip), int(mtu); got > want {
t.Errorf("fragment is too large, got %d want %d", got, want)
}
- if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want {
- t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
+ if i == 0 {
+ got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size()
+ // sourcePacketInfo does not have NetworkHeader added, simulate one.
+ want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size()
+ // Check that it kept the transport header in packet.TransportHeader if
+ // it fits in the first fragment.
+ if want < int(mtu) && got != want {
+ t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
+ }
}
- if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want {
+ if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
}
+ if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want {
+ t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want)
+ }
if i < len(packets)-1 {
sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
} else {
@@ -172,101 +434,19 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
}
-type errorChannel struct {
- *channel.Endpoint
- Ch chan *stack.PacketBuffer
- packetCollectorErrors []*tcpip.Error
-}
-
-// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
-// will return successive errors from packetCollectorErrors until the list is
-// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
- return &errorChannel{
- Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan *stack.PacketBuffer, size),
- packetCollectorErrors: packetCollectorErrors,
- }
-}
-
-// Drain removes all outbound packets from the channel and counts them.
-func (e *errorChannel) Drain() int {
- c := 0
- for {
- select {
- case <-e.Ch:
- c++
- default:
- return c
- }
- }
-}
-
-// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- select {
- case e.Ch <- pkt:
- default:
- }
-
- nextError := (*tcpip.Error)(nil)
- if len(e.packetCollectorErrors) > 0 {
- nextError = e.packetCollectorErrors[0]
- e.packetCollectorErrors = e.packetCollectorErrors[1:]
- }
- return nextError
-}
-
-type context struct {
- stack.Route
- linkEP *errorChannel
-}
-
-func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
- // Make the packet and write it.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- })
- ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- s.CreateNIC(1, ep)
- const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
- )
- s.AddAddress(1, ipv4.ProtocolNumber, src)
- {
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }})
- }
- r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute got %v, want %v", err, nil)
- }
- return context{
- Route: r,
- linkEP: ep,
- }
-}
-
func TestFragmentation(t *testing.T) {
var manyPayloadViewsSizes [1000]int
for i := range manyPayloadViewsSizes {
manyPayloadViewsSizes[i] = 7
}
fragTests := []struct {
- description string
- mtu uint32
- gso *stack.GSO
- hdrLength int
- extraLength int
- payloadViewsSizes []int
- expectedFrags int
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transportHeaderLength int
+ extraHeaderReserveLength int
+ payloadViewsSizes []int
+ expectedFrags int
}{
{"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
{"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
@@ -281,43 +461,29 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
- source := &stack.PacketBuffer{
- Header: hdr,
- // Save the source payload because WritePacket will modify it.
- Data: payload.Clone(nil),
- }
- c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ source := pkt.Clone()
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
- }, &stack.PacketBuffer{
- Header: hdr,
- Data: payload,
- })
+ }, pkt)
if err != nil {
- t.Errorf("err got %v, want %v", err, nil)
+ t.Errorf("got err = %s, want = nil", err)
}
- var results []*stack.PacketBuffer
- L:
- for {
- select {
- case pi := <-c.linkEP.Ch:
- results = append(results, pi)
- default:
- break L
- }
+ if got := len(ep.WrittenPackets); got != ft.expectedFrags {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags)
}
-
- if got, want := len(results), ft.expectedFrags; got != want {
- t.Errorf("len(result) got %d, want %d", got, want)
+ if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
- if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet len(result) got %d, want %d", got, want)
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
}
- compareFragments(t, results, source, ft.mtu)
+ compareFragments(t, ep.WrittenPackets, source, ft.mtu)
})
}
}
@@ -328,155 +494,376 @@ func TestFragmentationErrors(t *testing.T) {
fragTests := []struct {
description string
mtu uint32
- hdrLength int
+ transportHeaderLength int
payloadViewsSizes []int
- packetCollectorErrors []*tcpip.Error
+ err *tcpip.Error
+ allowPackets int
+ fragmentCount int
}{
- {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
- {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ {
+ description: "NoFrag",
+ mtu: 2000,
+ transportHeaderLength: 0,
+ payloadViewsSizes: []int{1000},
+ err: tcpip.ErrAborted,
+ allowPackets: 0,
+ fragmentCount: 1,
+ },
+ {
+ description: "ErrorOnFirstFrag",
+ mtu: 500,
+ transportHeaderLength: 0,
+ payloadViewsSizes: []int{1000},
+ err: tcpip.ErrAborted,
+ allowPackets: 0,
+ fragmentCount: 3,
+ },
+ {
+ description: "ErrorOnSecondFrag",
+ mtu: 500,
+ transportHeaderLength: 0,
+ payloadViewsSizes: []int{1000},
+ err: tcpip.ErrAborted,
+ allowPackets: 1,
+ fragmentCount: 3,
+ },
+ {
+ description: "ErrorOnFirstFragMTUSmallerThanHeader",
+ mtu: 500,
+ transportHeaderLength: 1000,
+ payloadViewsSizes: []int{500},
+ err: tcpip.ErrAborted,
+ allowPackets: 0,
+ fragmentCount: 4,
+ },
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
- c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets)
+ r := buildRoute(t, ep)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
- }, &stack.PacketBuffer{
- Header: hdr,
- Data: payload,
- })
- for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
- if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
- t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
- }
+ }, pkt)
+ if err != ft.err {
+ t.Errorf("got WritePacket() = %s, want = %s", err, ft.err)
}
- // We only need to check that last error because all the ones before are
- // nil.
- if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
- t.Errorf("err got %v, want %v", got, want)
+ if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
- if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
- t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
+ if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want)
}
})
}
}
func TestInvalidFragments(t *testing.T) {
+ const (
+ nicID = 1
+ linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ addr1 = "\x0a\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x02"
+ tos = 0
+ ident = 1
+ ttl = 48
+ protocol = 6
+ )
+
+ payloadGen := func(payloadLen int) []byte {
+ payload := make([]byte, payloadLen)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = 0x30
+ }
+ return payload
+ }
+
+ type fragmentData struct {
+ ipv4fields header.IPv4Fields
+ payload []byte
+ autoChecksum bool // if true, the Checksum field will be overwritten.
+ }
+
// These packets have both IHL and TotalLength set to 0.
- testCases := []struct {
+ tests := []struct {
name string
- packets [][]byte
+ fragments []fragmentData
wantMalformedIPPackets uint64
wantMalformedFragments uint64
}{
{
- "ihl_totallen_zero_valid_frag_offset",
- [][]byte{
- {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- },
- 1,
- 0,
- },
- {
- "ihl_totallen_zero_invalid_frag_offset",
- [][]byte{
- {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL and TotalLength zero, FragmentOffset non-zero",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: 0,
+ TOS: tos,
+ TotalLength: 0,
+ ID: ident,
+ Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments,
+ FragmentOffset: 59776,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(12),
+ autoChecksum: true,
+ },
},
- 1,
- 0,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 0,
},
{
- // Total Length of 37(20 bytes IP header + 17 bytes of
- // payload)
- // Frag Offset of 0x1ffe = 8190*8 = 65520
- // Leading to the fragment end to be past 65535.
- "ihl_totallen_valid_invalid_frag_offset_1",
- [][]byte{
- {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL and TotalLength zero, FragmentOffset zero",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: 0,
+ TOS: tos,
+ TotalLength: 0,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(12),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 0,
},
- // The following 3 tests were found by running a fuzzer and were
- // triggering a panic in the IPv4 reassembler code.
{
- "ihl_less_than_ipv4_minimum_size_1",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ // Payload 17 octets and Fragment offset 65520
+ // Leading to the fragment end to be past 65536.
+ name: "fragment ends past 65536",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 17,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 65520,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(17),
+ autoChecksum: true,
+ },
},
- 2,
- 0,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
{
- "ihl_less_than_ipv4_minimum_size_2",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ // Payload 16 octets and fragment offset 65520
+ // Leading to the fragment end to be exactly 65536.
+ name: "fragment ends exactly at 65536",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 65520,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(16),
+ autoChecksum: true,
+ },
},
- 2,
- 0,
+ wantMalformedIPPackets: 0,
+ wantMalformedFragments: 0,
},
{
- "ihl_less_than_ipv4_minimum_size_3",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL less than IPv4 minimum size",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 12,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 28,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 1944,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 12,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize - 12,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
},
- 2,
- 0,
+ wantMalformedIPPackets: 2,
+ wantMalformedFragments: 0,
},
{
- "fragment_with_short_total_len_extra_payload",
- [][]byte{
- {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "fragment with short TotalLength and extra payload",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize + 4,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 28,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 28816,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize + 4,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 4,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
{
- "multiple_fragments_with_more_fragments_set_to_false",
- [][]byte{
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ name: "multiple fragments with More Fragments flag set to false",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 128,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 8,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
}
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- const nicID tcpip.NICID = 42
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{
- ipv4.NewProtocol(),
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
},
})
+ e := channel.New(0, 1500, linkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ }
- var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30})
- var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31})
- ep := channel.New(10, 1500, linkAddr)
- s.CreateNIC(nicID, sniffer.New(ep))
+ for _, f := range test.fragments {
+ pktSize := header.IPv4MinimumSize + len(f.payload)
+ hdr := buffer.NewPrependable(pktSize)
- for _, pkt := range tc.packets {
- ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
- })
+ ip := header.IPv4(hdr.Prepend(pktSize))
+ ip.Encode(&f.ipv4fields)
+ copy(ip[header.IPv4MinimumSize:], f.payload)
+
+ if f.autoChecksum {
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ }
+
+ vv := hdr.View().ToVectorisedView()
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
}
- if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want {
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
}
- if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want {
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
}
})
@@ -486,12 +873,16 @@ func TestInvalidFragments(t *testing.T) {
// TestReceiveFragments feeds fragments in through the incoming packet path to
// test reassembly
func TestReceiveFragments(t *testing.T) {
- const addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
- const addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
- const nicID = 1
+ const (
+ nicID = 1
+
+ addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
+ addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
+ addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3
+ )
// Build and return a UDP header containing payload.
- udpGen := func(payloadLen int, multiplier uint8) buffer.View {
+ udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View {
payload := buffer.NewView(payloadLen)
for i := 0; i < len(payload); i++ {
payload[i] = uint8(i) * multiplier
@@ -507,20 +898,32 @@ func TestReceiveFragments(t *testing.T) {
Length: uint16(udpLength),
})
copy(u.Payload(), payload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength))
sum = header.Checksum(payload, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
return hdr.View()
}
// UDP header plus a payload of 0..256
- ipv4Payload1 := udpGen(256, 1)
- udpPayload1 := ipv4Payload1[header.UDPMinimumSize:]
+ ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2)
+ udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:]
+ ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2)
+ udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:]
// UDP header plus a payload of 0..256 in increments of 2.
- ipv4Payload2 := udpGen(128, 2)
- udpPayload2 := ipv4Payload2[header.UDPMinimumSize:]
+ ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2)
+ udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:]
+ // UDP header plus a payload of 0..256 in increments of 3.
+ // Used to test cases where the fragment blocks are not a multiple of
+ // the fragment block size of 8 (RFC 791 section 3.1 page 14).
+ ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2)
+ udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:]
+ // Used to test the max reassembled payload length (65,535 octets).
+ ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2)
+ udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:]
type fragmentData struct {
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
id uint16
flags uint8
fragmentOffset uint16
@@ -536,22 +939,40 @@ func TestReceiveFragments(t *testing.T) {
name: "No fragmentation",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: 0,
fragmentOffset: 0,
- payload: ipv4Payload1,
+ payload: ipv4Payload1Addr1ToAddr2,
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "No fragmentation with size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 0,
+ payload: ipv4Payload3Addr1ToAddr2,
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
},
{
name: "More fragments without payload",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1,
+ payload: ipv4Payload1Addr1ToAddr2,
},
},
expectedPayloads: nil,
@@ -560,10 +981,12 @@ func TestReceiveFragments(t *testing.T) {
name: "Non-zero fragment offset without payload",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: 0,
fragmentOffset: 8,
- payload: ipv4Payload1,
+ payload: ipv4Payload1Addr1ToAddr2,
},
},
expectedPayloads: nil,
@@ -572,34 +995,108 @@ func TestReceiveFragments(t *testing.T) {
name: "Two fragments",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1Addr1ToAddr2[64:],
},
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments out of order",
+ fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: 0,
fragmentOffset: 64,
- payload: ipv4Payload1[64:],
+ payload: ipv4Payload1Addr1ToAddr2[64:],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with last fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload3Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload3Addr1ToAddr2[64:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with first fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload3Addr1ToAddr2[:63],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 63,
+ payload: ipv4Payload3Addr1ToAddr2[63:],
+ },
+ },
+ expectedPayloads: nil,
},
{
name: "Second fragment has MoreFlags set",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 64,
- payload: ipv4Payload1[64:],
+ payload: ipv4Payload1Addr1ToAddr2[64:],
},
},
expectedPayloads: nil,
@@ -608,16 +1105,20 @@ func TestReceiveFragments(t *testing.T) {
name: "Two fragments with different IDs",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 2,
flags: 0,
fragmentOffset: 64,
- payload: ipv4Payload1[64:],
+ payload: ipv4Payload1Addr1ToAddr2[64:],
},
},
expectedPayloads: nil,
@@ -626,52 +1127,122 @@ func TestReceiveFragments(t *testing.T) {
name: "Two interleaved fragmented packets",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 2,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload2[:64],
+ payload: ipv4Payload2Addr1ToAddr2[:64],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: 0,
fragmentOffset: 64,
- payload: ipv4Payload1[64:],
+ payload: ipv4Payload1Addr1ToAddr2[64:],
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 2,
flags: 0,
fragmentOffset: 64,
- payload: ipv4Payload2[64:],
+ payload: ipv4Payload2Addr1ToAddr2[64:],
},
},
- expectedPayloads: [][]byte{udpPayload1, udpPayload2},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2},
+ },
+ {
+ name: "Two interleaved fragmented packets from different sources but with same ID",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr3ToAddr2[:32],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1Addr1ToAddr2[64:],
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 32,
+ payload: ipv4Payload1Addr3ToAddr2[32:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2},
},
{
name: "Fragment without followup",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
id: 1,
flags: header.IPv4FlagMoreFragments,
fragmentOffset: 0,
- payload: ipv4Payload1[:64],
+ payload: ipv4Payload1Addr1ToAddr2[:64],
},
},
expectedPayloads: nil,
},
+ {
+ name: "Two fragments reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload4Addr1ToAddr2[:65512],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 65512,
+ payload: ipv4Payload4Addr1ToAddr2[65512:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
+ },
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Setup a stack and endpoint.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00"))
if err := s.CreateNIC(nicID, e); err != nil {
@@ -711,16 +1282,16 @@ func TestReceiveFragments(t *testing.T) {
FragmentOffset: frag.fragmentOffset,
TTL: 64,
Protocol: uint8(header.UDPProtocolNumber),
- SrcAddr: addr1,
- DstAddr: addr2,
+ SrcAddr: frag.srcAddr,
+ DstAddr: frag.dstAddr,
})
vv := hdr.View().ToVectorisedView()
vv.AppendView(frag.payload)
- e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: vv,
- })
+ }))
}
if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
@@ -743,3 +1314,189 @@ func TestReceiveFragments(t *testing.T) {
})
}
}
+
+func TestWriteStats(t *testing.T) {
+ const nPackets = 3
+ tests := []struct {
+ name string
+ setup func(*testing.T, *stack.Stack)
+ allowPackets int
+ expectSent int
+ expectDropped int
+ expectWritten int
+ }{
+ {
+ name: "Accept all",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets,
+ expectDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Accept all with error",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: nPackets - 1,
+ expectSent: nPackets - 1,
+ expectDropped: 0,
+ expectWritten: nPackets - 1,
+ }, {
+ name: "Drop all",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %s", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule that matches only 1
+ // of the 3 packets.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %s", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectDropped: 1,
+ expectWritten: nPackets,
+ },
+ }
+
+ // Parameterize the tests to run with both WritePacket and WritePackets.
+ writers := []struct {
+ name string
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ }{
+ {
+ name: "WritePacket",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ nWritten := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
+ return nWritten, err
+ }
+ nWritten++
+ }
+ return nWritten, nil
+ },
+ }, {
+ name: "WritePackets",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
+ },
+ },
+ }
+
+ for _, writer := range writers {
+ t.Run(writer.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ rt := buildRoute(t, ep)
+
+ var pkts stack.PacketBufferList
+ for i := 0; i < nPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(header.UDPMinimumSize)
+ pkts.PushBack(pkt)
+ }
+
+ test.setup(t, rt.Stack())
+
+ nWritten, _ := writer.writePackets(&rt, pkts)
+
+ if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
+ t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
+ t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ }
+ if nWritten != test.expectWritten {
+ t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ }
+ })
+ }
+ })
+ }
+}
+
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC(1, _) failed: %s", err)
+ }
+ const (
+ src = "\x10\x00\x00\x01"
+ dst = "\x10\x00\x00\x02"
+ )
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ if err != nil {
+ t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err)
+ }
+ return rt
+}
+
+// limitedMatcher is an iptables matcher that matches after a certain number of
+// packets are checked against it.
+type limitedMatcher struct {
+ limit int
+}
+
+// Name implements Matcher.Name.
+func (*limitedMatcher) Name() string {
+ return "limitedMatcher"
+}
+
+// Match implements Matcher.Match.
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+ if lm.limit == 0 {
+ return true, false
+ }
+ lm.limit--
+ return false, false
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 3f71fc520..97adbcbd4 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -5,16 +5,19 @@ package(licenses = ["notice"])
go_library(
name = "ipv6",
srcs = [
+ "dhcpv6configurationfromndpra_string.go",
"icmp.go",
"ipv6.go",
+ "ndp.go",
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
- "//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
],
)
@@ -35,10 +38,11 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go
index d199ded6a..09ba133b1 100644
--- a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go
+++ b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go
@@ -14,7 +14,7 @@
// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT.
-package stack
+package ipv6
import "strconv"
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 2ff7eedf4..8e9def6b8 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -39,8 +39,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// is truncated, which would cause IsValid to return false.
//
// Drop packet if it doesn't have the basic IPv6 header or if the
- // original source address doesn't match the endpoint's address.
- if hdr.SourceAddress() != e.id.LocalAddress {
+ // original source address doesn't match an address we own.
+ src := hdr.SourceAddress()
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -67,7 +68,60 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
}
// Deliver the control packet to the transport endpoint.
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+}
+
+// getLinkAddrOption searches NDP options for a given link address option using
+// the provided getAddr function as a filter. Returns the link address if
+// found; otherwise, returns the zero link address value. Also returns true if
+// the options are valid as per the wire format, false otherwise.
+func getLinkAddrOption(it header.NDPOptionIterator, getAddr func(header.NDPOption) tcpip.LinkAddress) (tcpip.LinkAddress, bool) {
+ var linkAddr tcpip.LinkAddress
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ return "", false
+ }
+ if done {
+ break
+ }
+ if addr := getAddr(opt); len(addr) != 0 {
+ // No RFCs define what to do when an NDP message has multiple Link-Layer
+ // Address options. Since no interface can have multiple link-layer
+ // addresses, we consider such messages invalid.
+ if len(linkAddr) != 0 {
+ return "", false
+ }
+ linkAddr = addr
+ }
+ }
+ return linkAddr, true
+}
+
+// getSourceLinkAddr searches NDP options for the source link address option.
+// Returns the link address if found; otherwise, returns the zero link address
+// value. Also returns true if the options are valid as per the wire format,
+// false otherwise.
+func getSourceLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
+ return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress {
+ if src, ok := opt.(header.NDPSourceLinkLayerAddressOption); ok {
+ return src.EthernetAddress()
+ }
+ return ""
+ })
+}
+
+// getTargetLinkAddr searches NDP options for the target link address option.
+// Returns the link address if found; otherwise, returns the zero link address
+// value. Also returns true if the options are valid as per the wire format,
+// false otherwise.
+func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
+ return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress {
+ if dst, ok := opt.(header.NDPTargetLinkLayerAddressOption); ok {
+ return dst.EthernetAddress()
+ }
+ return ""
+ })
}
func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
@@ -83,7 +137,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
h := header.ICMPv6(v)
- iph := header.IPv6(pkt.NetworkHeader)
+ iph := header.IPv6(pkt.NetworkHeader().View())
// Validate ICMPv6 checksum before processing the packet.
//
@@ -128,13 +182,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch header.ICMPv6(hdr).Code() {
+ case header.ICMPv6NetworkUnreachable:
+ e.handleControl(stack.ControlNetworkUnreachable, 0, pkt)
case header.ICMPv6PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
+ if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
@@ -144,22 +200,16 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// NDP messages cannot be fragmented. Also note that in the common case NDP
// datagrams are very small and ToView() will not incur allocations.
ns := header.NDPNeighborSolicit(payload.ToView())
- it, err := ns.Options().Iter(true)
- if err != nil {
- // If we have a malformed NDP NS option, drop the packet.
+ targetAddr := ns.TargetAddress()
+
+ // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast
+ // address.
+ if header.IsV6MulticastAddress(targetAddr) {
received.Invalid.Increment()
return
}
- targetAddr := ns.TargetAddress()
- s := r.Stack()
- if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
- // We will only get an error if the NIC is unrecognized, which should not
- // happen. For now, drop this packet.
- //
- // TODO(b/141002840): Handle this better?
- return
- } else if isTentative {
+ if e.hasTentativeAddr(targetAddr) {
// If the target address is tentative and the source of the packet is a
// unicast (specified) address, then the source of the packet is
// attempting to perform address resolution on the target. In this case,
@@ -172,7 +222,20 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// stack know so it can handle such a scenario and do nothing further with
// the NS.
if r.RemoteAddress == header.IPv6Any {
- s.DupTentativeAddrDetected(e.nicID, targetAddr)
+ // We would get an error if the address no longer exists or the address
+ // is no longer tentative (DAD resolved between the call to
+ // hasTentativeAddr and this point). Both of these are valid scenarios:
+ // 1) An address may be removed at any time.
+ // 2) As per RFC 4862 section 5.4, DAD is not a perfect:
+ // "Note that the method for detecting duplicates
+ // is not completely reliable, and it is possible that duplicate
+ // addresses will still exist"
+ //
+ // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
+ // address is detected for an assigned address.
+ if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
+ }
}
// Do not handle neighbor solicitations targeted to an address that is
@@ -184,39 +247,22 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// so the packet is processed as defined in RFC 4861, as per RFC 4862
// section 5.4.3.
- // Is the NS targetting us?
- if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
+ // Is the NS targeting us?
+ if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
return
}
- // If the NS message contains the Source Link-Layer Address option, update
- // the link address cache with the value of the option.
- //
- // TODO(b/148429853): Properly process the NS message and do Neighbor
- // Unreachability Detection.
- var sourceLinkAddr tcpip.LinkAddress
- for {
- opt, done, err := it.Next()
- if err != nil {
- // This should never happen as Iter(true) above did not return an error.
- panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
- }
- if done {
- break
- }
-
- switch opt := opt.(type) {
- case header.NDPSourceLinkLayerAddressOption:
- // No RFCs define what to do when an NS message has multiple Source
- // Link-Layer Address options. Since no interface can have multiple
- // link-layer addresses, we consider such messages invalid.
- if len(sourceLinkAddr) != 0 {
- received.Invalid.Increment()
- return
- }
+ it, err := ns.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
- sourceLinkAddr = opt.EthernetAddress()
- }
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
}
unspecifiedSource := r.RemoteAddress == header.IPv6Any
@@ -234,8 +280,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
} else if unspecifiedSource {
received.Invalid.Increment()
return
+ } else if e.nud != nil {
+ e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
}
// ICMPv6 Neighbor Solicit messages are always sent to
@@ -274,8 +322,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
optsSerializer := header.NDPOptionsSerializer{
header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress),
}
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
- packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()),
+ })
+ packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet.SetType(header.ICMPv6NeighborAdvert)
na := header.NDPNeighborAdvert(packet.NDPPayload())
na.SetSolicitedFlag(solicited)
@@ -291,9 +342,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
//
// The IP Hop Limit field has a value of 255, i.e., the packet
// could not possibly have been forwarded by a router.
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
- Header: hdr,
- }); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -301,7 +350,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
+ if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertSize {
received.Invalid.Increment()
return
}
@@ -311,28 +360,34 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// 5, NDP messages cannot be fragmented. Also note that in the common case
// NDP datagrams are very small and ToView() will not incur allocations.
na := header.NDPNeighborAdvert(payload.ToView())
- it, err := na.Options().Iter(true)
- if err != nil {
- // If we have a malformed NDP NA option, drop the packet.
- received.Invalid.Increment()
- return
- }
-
targetAddr := na.TargetAddress()
- stack := r.Stack()
-
- if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil {
- // We will only get an error if the NIC is unrecognized, which should not
- // happen. For now short-circuit this packet.
- //
- // TODO(b/141002840): Handle this better?
- return
- } else if isTentative {
+ if e.hasTentativeAddr(targetAddr) {
// We just got an NA from a node that owns an address we are performing
// DAD on, implying the address is not unique. In this case we let the
// stack know so it can handle such a scenario and do nothing furthur with
// the NDP NA.
- stack.DupTentativeAddrDetected(e.nicID, targetAddr)
+ //
+ // We would get an error if the address no longer exists or the address
+ // is no longer tentative (DAD resolved between the call to
+ // hasTentativeAddr and this point). Both of these are valid scenarios:
+ // 1) An address may be removed at any time.
+ // 2) As per RFC 4862 section 5.4, DAD is not a perfect:
+ // "Note that the method for detecting duplicates
+ // is not completely reliable, and it is possible that duplicate
+ // addresses will still exist"
+ //
+ // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
+ // address is detected for an assigned address.
+ if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
+ }
+ return
+ }
+
+ it, err := na.Options().Iter(false /* check */)
+ if err != nil {
+ // If we have a malformed NDP NA option, drop the packet.
+ received.Invalid.Increment()
return
}
@@ -345,58 +400,64 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// TODO(b/143147598): Handle the scenario described above. Also inform the
// netstack integration that a duplicate address was detected outside of
// DAD.
+ targetLinkAddr, ok := getTargetLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
// If the NA message has the target link layer option, update the link
// address cache with the link address for the target of the message.
- //
- // TODO(b/148429853): Properly process the NA message and do Neighbor
- // Unreachability Detection.
- var targetLinkAddr tcpip.LinkAddress
- for {
- opt, done, err := it.Next()
- if err != nil {
- // This should never happen as Iter(true) above did not return an error.
- panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
- }
- if done {
- break
- }
-
- switch opt := opt.(type) {
- case header.NDPTargetLinkLayerAddressOption:
- // No RFCs define what to do when an NA message has multiple Target
- // Link-Layer Address options. Since no interface can have multiple
- // link-layer addresses, we consider such messages invalid.
- if len(targetLinkAddr) != 0 {
- received.Invalid.Increment()
- return
- }
-
- targetLinkAddr = opt.EthernetAddress()
+ if len(targetLinkAddr) != 0 {
+ if e.nud == nil {
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
+ return
}
- }
- if len(targetLinkAddr) != 0 {
- e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
+ e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{
+ Solicited: na.SolicitedFlag(),
+ Override: na.OverrideFlag(),
+ IsRouter: na.RouterFlag(),
+ })
}
case header.ICMPv6EchoRequest:
received.EchoRequest.Increment()
- icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize)
+ icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize)
if !ok {
received.Invalid.Increment()
return
}
- pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize)
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
- packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+
+ remoteLinkAddr := r.RemoteLinkAddress
+
+ // As per RFC 4291 section 2.7, multicast addresses must not be used as
+ // source addresses in IPv6 packets.
+ localAddr := r.LocalAddress
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ localAddr = ""
+ }
+
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ // If we cannot find a route to the destination, silently drop the packet.
+ return
+ }
+ defer r.Release()
+
+ // Use the link address from the source of the original packet.
+ r.ResolveWith(remoteLinkAddr)
+
+ replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
+ Data: pkt.Data,
+ })
+ packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
copy(packet, icmpHdr)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
- Header: hdr,
- Data: pkt.Data,
- }); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -418,27 +479,75 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
case header.ICMPv6RouterSolicit:
received.RouterSolicit.Increment()
- if !isNDPValid() {
+
+ //
+ // Validate the RS as per RFC 4861 section 6.1.1.
+ //
+
+ // Is the NDP payload of sufficient size to hold a Router Solictation?
+ if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize {
received.Invalid.Increment()
return
}
- case header.ICMPv6RouterAdvert:
- received.RouterAdvert.Increment()
+ stack := r.Stack()
- // Is the NDP payload of sufficient size to hold a Router
- // Advertisement?
- if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() {
+ // Is the networking stack operating as a router?
+ if !stack.Forwarding(ProtocolNumber) {
+ // ... No, silently drop the packet.
+ received.RouterOnlyPacketsDroppedByHost.Increment()
+ return
+ }
+
+ // Note that in the common case NDP datagrams are very small and ToView()
+ // will not incur allocations.
+ rs := header.NDPRouterSolicit(payload.ToView())
+ it, err := rs.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the packet.
received.Invalid.Increment()
return
}
- routerAddr := iph.SourceAddress()
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
+ received.Invalid.Increment()
+ return
+ }
+
+ // If the RS message has the source link layer option, update the link
+ // address cache with the link address for the source of the message.
+ if len(sourceLinkAddr) != 0 {
+ // As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST
+ // NOT be included when the source IP address is the unspecified address.
+ // Otherwise, it SHOULD be included on link layers that have addresses.
+ if r.RemoteAddress == header.IPv6Any {
+ received.Invalid.Increment()
+ return
+ }
+
+ if e.nud != nil {
+ // A RS with a specified source IP address modifies the NUD state
+ // machine in the same way a reachability probe would.
+ e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ }
+ }
+
+ case header.ICMPv6RouterAdvert:
+ received.RouterAdvert.Increment()
//
// Validate the RA as per RFC 4861 section 6.1.2.
//
+ // Is the NDP payload of sufficient size to hold a Router Advertisement?
+ if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize {
+ received.Invalid.Increment()
+ return
+ }
+
+ routerAddr := iph.SourceAddress()
+
// Is the IP Source Address a link-local address?
if !header.IsV6LinkLocalAddress(routerAddr) {
// ...No, silently drop the packet.
@@ -446,16 +555,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- // The remainder of payload must be only the router advertisement, so
- // payload.ToView() always returns the advertisement. Per RFC 6980 section
- // 5, NDP messages cannot be fragmented. Also note that in the common case
- // NDP datagrams are very small and ToView() will not incur allocations.
+ // Note that in the common case NDP datagrams are very small and ToView()
+ // will not incur allocations.
ra := header.NDPRouterAdvert(payload.ToView())
- opts := ra.Options()
+ it, err := ra.Options().Iter(false /* check */)
+ if err != nil {
+ // Options are not valid as per the wire format, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
- // Are options valid as per the wire format?
- if _, err := opts.Iter(true); err != nil {
- // ...No, silently drop the packet.
+ sourceLinkAddr, ok := getSourceLinkAddr(it)
+ if !ok {
received.Invalid.Increment()
return
}
@@ -465,12 +576,33 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// as RFC 4861 section 6.1.2 is concerned.
//
- // Tell the NIC to handle the RA.
- stack := r.Stack()
- rxNICID := r.NICID()
- stack.HandleNDPRA(rxNICID, routerAddr, ra)
+ // If the RA has the source link layer option, update the link address
+ // cache with the link address for the advertised router.
+ if len(sourceLinkAddr) != 0 && e.nud != nil {
+ e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ }
+
+ e.mu.Lock()
+ e.mu.ndp.handleRA(routerAddr, ra)
+ e.mu.Unlock()
case header.ICMPv6RedirectMsg:
+ // TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating
+ // this redirect message, as per RFC 4871 section 7.3.3:
+ //
+ // "A Neighbor Cache entry enters the STALE state when created as a
+ // result of receiving packets other than solicited Neighbor
+ // Advertisements (i.e., Router Solicitations, Router Advertisements,
+ // Redirects, and Neighbor Solicitations). These packets contain the
+ // link-layer address of either the sender or, in the case of Redirect,
+ // the redirection target. However, receipt of these link-layer
+ // addresses does not confirm reachability of the forward-direction path
+ // to that node. Placing a newly created Neighbor Cache entry for which
+ // the link-layer address is known in the STALE state provides assurance
+ // that path failures are detected quickly. In addition, should a cached
+ // link-layer address be modified due to receiving one of the above
+ // messages, the state SHOULD also be set to STALE to provide prompt
+ // verification that the path to the new link-layer address is working."
received.RedirectMsg.Increment()
if !isNDPValid() {
received.Invalid.Increment()
@@ -494,8 +626,6 @@ const (
icmpV6LengthOffset = 25
)
-var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-
var _ stack.LinkAddressResolver = (*protocol)(nil)
// LinkAddressProtocol implements stack.LinkAddressResolver.
@@ -504,7 +634,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
snaddr := header.SolicitedNodeAddr(addr)
// TODO(b/148672031): Use stack.FindRoute instead of manually creating the
@@ -513,19 +643,26 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
r := &stack.Route{
LocalAddress: localAddr,
RemoteAddress: snaddr,
- RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr),
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ if len(r.RemoteLinkAddress) == 0 {
+ r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr)
}
- hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- copy(pkt[icmpV6OptOffset-len(addr):], addr)
- pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
- pkt[icmpV6LengthOffset] = 1
- copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
-
- length := uint16(hdr.UsedLength())
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize,
+ })
+ icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+ icmpHdr.SetType(header.ICMPv6NeighborSolicit)
+ copy(icmpHdr[icmpV6OptOffset-len(addr):], addr)
+ icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr
+ icmpHdr[icmpV6LengthOffset] = 1
+ copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress())
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ length := uint16(pkt.Size())
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(header.ICMPv6ProtocolNumber),
@@ -535,9 +672,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
})
// TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{
- Header: hdr,
- })
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
@@ -547,3 +682,159 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
return tcpip.LinkAddress([]byte(nil)), false
}
+
+// ======= ICMP Error packet generation =========
+
+// icmpReason is a marker interface for IPv6 specific ICMP errors.
+type icmpReason interface {
+ isICMPReason()
+}
+
+// icmpReasonParameterProblem is an error during processing of extension headers
+// or the fixed header defined in RFC 4443 section 3.4.
+type icmpReasonParameterProblem struct {
+ code header.ICMPv6Code
+
+ // respondToMulticast indicates that we are sending a packet that falls under
+ // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2:
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are
+ // two exceptions to this rule: (1) the Packet Too Big Message
+ // (Section 3.2) to allow Path MTU discovery to work for IPv6
+ // multicast, and (2) the Parameter Problem Message, Code 2
+ // (Section 3.4) reporting an unrecognized IPv6 option (see
+ // Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ respondToMulticast bool
+
+ // pointer is defined in the RFC 4443 setion 3.4 which reads:
+ //
+ // Pointer Identifies the octet offset within the invoking packet
+ // where the error was detected.
+ //
+ // The pointer will point beyond the end of the ICMPv6
+ // packet if the field in error is beyond what can fit
+ // in the maximum size of an ICMPv6 error message.
+ pointer uint32
+}
+
+func (*icmpReasonParameterProblem) isICMPReason() {}
+
+// icmpReasonPortUnreachable is an error where the transport protocol has no
+// listener and no alternative means to inform the sender.
+type icmpReasonPortUnreachable struct{}
+
+func (*icmpReasonPortUnreachable) isICMPReason() {}
+
+// returnError takes an error descriptor and generates the appropriate ICMP
+// error packet for IPv6 and sends it.
+func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ stats := r.Stats().ICMP
+ sent := stats.V6PacketsSent
+ if !r.Stack().AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
+ // Only send ICMP error if the address is not a multicast v6
+ // address and the source is not the unspecified address.
+ //
+ // There are exceptions to this rule.
+ // See: point e.3) RFC 4443 section-2.4
+ //
+ // (e) An ICMPv6 error message MUST NOT be originated as a result of
+ // receiving the following:
+ //
+ // (e.1) An ICMPv6 error message.
+ //
+ // (e.2) An ICMPv6 redirect message [IPv6-DISC].
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are
+ // two exceptions to this rule: (1) the Packet Too Big Message
+ // (Section 3.2) to allow Path MTU discovery to work for IPv6
+ // multicast, and (2) the Parameter Problem Message, Code 2
+ // (Section 3.4) reporting an unrecognized IPv6 option (see
+ // Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ //
+ var allowResponseToMulticast bool
+ if reason, ok := reason.(*icmpReasonParameterProblem); ok {
+ allowResponseToMulticast = reason.respondToMulticast
+ }
+
+ if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any {
+ return nil
+ }
+
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+
+ if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
+ // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
+ // Unfortunately at this time ICMP Packets do not have a transport
+ // header separated out. It is in the Data part so we need to
+ // separate it out now. We will just pretend it is a minimal length
+ // ICMP packet as we don't really care if any later bits of a
+ // larger ICMP packet are in the header view or in the Data view.
+ transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize)
+ if !ok {
+ return nil
+ }
+ typ := header.ICMPv6(transport).Type()
+ if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg {
+ return nil
+ }
+ }
+
+ // As per RFC 4443 section 2.4
+ //
+ // (c) Every ICMPv6 error message (type < 128) MUST include
+ // as much of the IPv6 offending (invoking) packet (the
+ // packet that caused the error) as possible without making
+ // the error message packet exceed the minimum IPv6 MTU
+ // [IPv6].
+ mtu := int(r.MTU())
+ if mtu > header.IPv6MinimumMTU {
+ mtu = header.IPv6MinimumMTU
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
+ available := int(mtu) - headerLen
+ if available < header.IPv6MinimumSize {
+ return nil
+ }
+ payloadLen := network.Size() + transport.Size() + pkt.Data.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+ payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ payload.CapLength(payloadLen)
+
+ newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: headerLen,
+ Data: payload,
+ })
+ newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+
+ icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
+ var counter *tcpip.StatCounter
+ switch reason := reason.(type) {
+ case *icmpReasonParameterProblem:
+ icmpHdr.SetType(header.ICMPv6ParamProblem)
+ icmpHdr.SetCode(reason.code)
+ icmpHdr.SetTypeSpecific(reason.pointer)
+ counter = sent.ParamProblem
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6PortUnreachable)
+ counter = sent.DstUnreachable
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data))
+ err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt)
+ if err != nil {
+ sent.Dropped.Increment()
+ return err
+ }
+ counter.Increment()
+ return nil
+}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 52a01b44e..31370c1d4 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -31,9 +31,14 @@ import (
)
const (
+ nicID = 1
+
linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
)
var (
@@ -46,7 +51,10 @@ type stubLinkEndpoint struct {
}
func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return 0
+ // Indicate that resolution for link layer addresses is required to send
+ // packets over this link. This is needed so the NIC knows to allocate a
+ // neighbor table.
+ return stack.CapabilityResolutionRequired
}
func (*stubLinkEndpoint) MaxHeaderLength() uint16 {
@@ -67,7 +75,8 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
+func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition {
+ return stack.TransportPacketHandled
}
type stubLinkAddressCache struct {
@@ -81,16 +90,212 @@ func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtoco
func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {
}
+type stubNUDHandler struct{}
+
+var _ stack.NUDHandler = (*stubNUDHandler)(nil)
+
+func (*stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) {
+}
+
+func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) {
+}
+
+func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
+}
+
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct{}
+
+func (*testInterface) ID() tcpip.NICID {
+ return 0
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
+func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
+ return nil
+}
+
func TestICMPCounts(t *testing.T) {
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ UseNeighborCache: test.useNeighborCache,
+ })
+ {
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
+ r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ }
+ defer r.Release()
+
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
+ types := []struct {
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ }{
+ {
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ },
+ {
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ },
+ {
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ },
+ {
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ },
+ }
+
+ handleIPv6Payload := func(icmp header.ICMPv6) {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize,
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ep.HandlePacket(&r, pkt)
+ }
+
+ for _, typ := range types {
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ handleIPv6Payload(icmp)
+ }
+
+ // Construct an empty ICMP packet so that
+ // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
+ handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
+
+ icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
+ visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
+ if got, want := s.Value(), uint64(1); got != want {
+ t.Errorf("got %s = %d, want = %d", name, got, want)
+ }
+ })
+ if t.Failed() {
+ t.Logf("stats:\n%+v", s.Stats())
+ }
+ })
+ }
+}
+
+func TestICMPCountsWithNeighborCache(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ UseNeighborCache: true,
})
{
- if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
}
}
@@ -102,7 +307,7 @@ func TestICMPCounts(t *testing.T) {
s.SetRouteTable(
[]tcpip.Route{{
Destination: subnet,
- NIC: 1,
+ NIC: nicID,
}},
)
}
@@ -111,14 +316,16 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{})
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
}
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
}
defer r.Release()
@@ -180,7 +387,11 @@ func TestICMPCounts(t *testing.T) {
}
handleIPv6Payload := func(icmp header.ICMPv6) {
- ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize,
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
@@ -188,10 +399,7 @@ func TestICMPCounts(t *testing.T) {
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, &stack.PacketBuffer{
- NetworkHeader: buffer.View(ip),
- Data: buffer.View(icmp).ToVectorisedView(),
- })
+ ep.HandlePacket(&r, pkt)
}
for _, typ := range types {
@@ -248,35 +456,34 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab
func newTestContext(t *testing.T) *testContext {
c := &testContext{
s0: stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
}),
s1: stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
}),
}
- const defaultMTU = 65536
- c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+ c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0)
wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
if testing.Verbose() {
wrappedEP0 = sniffer.New(wrappedEP0)
}
- if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
+ if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
- if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
- if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
+ if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
+ if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil {
t.Fatalf("AddAddress lladdr1: %v", err)
}
@@ -287,7 +494,7 @@ func newTestContext(t *testing.T) *testContext {
c.s0.SetRouteTable(
[]tcpip.Route{{
Destination: subnet0,
- NIC: 1,
+ NIC: nicID,
}},
)
subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
@@ -297,7 +504,7 @@ func newTestContext(t *testing.T) *testContext {
c.s1.SetRouteTable(
[]tcpip.Route{{
Destination: subnet1,
- NIC: 1,
+ NIC: nicID,
}},
)
@@ -321,12 +528,10 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
pi, _ := args.src.ReadContext(context.Background())
{
- views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
- size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size()
- vv := buffer.NewVectorisedView(size, views)
- args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{
- Data: vv,
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()),
})
+ args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt)
}
if pi.Proto != ProtocolNumber {
@@ -338,7 +543,9 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
}
- ipv6 := header.IPv6(pi.Pkt.Header.View())
+ // Pull the full payload since network header. Needed for header.IPv6 to
+ // extract its payload.
+ ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader()))
transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
if transProto != header.ICMPv6ProtocolNumber {
t.Errorf("unexpected transport protocol number %d", transProto)
@@ -358,9 +565,9 @@ func TestLinkResolution(t *testing.T) {
c := newTestContext(t)
defer c.cleanup()
- r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
}
defer r.Release()
@@ -375,14 +582,14 @@ func TestLinkResolution(t *testing.T) {
var wq waiter.Queue
ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err)
}
for {
- _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}})
+ _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}})
if resCh != nil {
if err != tcpip.ErrNoLinkAddress {
- t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err)
+ t.Fatalf("ep.Write(_) = (_, <non-nil>, %s), want = (_, <non-nil>, tcpip.ErrNoLinkAddress)", err)
}
for _, args := range []routeArgs{
{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
@@ -398,7 +605,7 @@ func TestLinkResolution(t *testing.T) {
continue
}
if err != nil {
- t.Fatalf("ep.Write(_) = _, _, %s", err)
+ t.Fatalf("ep.Write(_) = (_, _, %s)", err)
}
break
}
@@ -423,6 +630,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
size int
extraData []byte
statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ routerOnly bool
}{
{
name: "DstUnreachable",
@@ -479,6 +687,8 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterSolicit
},
+ // Hosts MUST silently discard any received Router Solicitation messages.
+ routerOnly: true,
},
{
name: "RouterAdvert",
@@ -515,83 +725,133 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
},
}
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr0)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }},
- )
- }
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
- handleIPv6Payload := func(checksum bool) {
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- if checksum {
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, typ := range types {
+ for _, isRouter := range []bool{false, true} {
+ name := typ.name
+ if isRouter {
+ name += " (Router)"
+ }
+ t.Run(name, func(t *testing.T) {
+ e := channel.New(0, 1280, linkAddr0)
+
+ // Indicate that resolution for link layer addresses is required to
+ // send packets over this link. This is needed so the NIC knows to
+ // allocate a neighbor table.
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: test.useNeighborCache,
+ })
+ if isRouter {
+ // Enabling forwarding makes the stack act as a router.
+ s.SetForwarding(ProtocolNumber, true)
+ }
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
+ }
+
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(checksum bool) {
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ if checksum {
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
+ }
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
+ })
+ e.InjectInbound(ProtocolNumber, pkt)
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ routerOnly := stats.RouterOnlyPacketsDroppedByHost
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := routerOnly.Value(); got != 0 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Router only count should not have increased.
+ if got := routerOnly.Value(); got != 0 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ if !isRouter && typ.routerOnly && test.useNeighborCache {
+ // Router only count should have increased.
+ if got := routerOnly.Value(); got != 1 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got)
+ }
+ }
+ })
}
- ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
- })
- }
-
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
}
})
}
@@ -692,13 +952,13 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Run(typ.name, func(t *testing.T) {
e := channel.New(10, 1280, linkAddr0)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
}
{
@@ -709,7 +969,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
s.SetRouteTable(
[]tcpip.Route{{
Destination: subnet,
- NIC: 1,
+ NIC: nicID,
}},
)
}
@@ -717,12 +977,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
icmpSize := size + payloadSize
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(typ)
- payloadFn(pkt.Payload())
+ icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize))
+ icmpHdr.SetType(typ)
+ payloadFn(icmpHdr.Payload())
if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{}))
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -733,9 +993,10 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
})
+ e.InjectInbound(ProtocolNumber, pkt)
}
stats := s.Stats().ICMP.V6PacketsReceived
@@ -747,7 +1008,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Without setting checksum, the incoming packet should
@@ -758,13 +1019,13 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
}
// Rx count of type typ.typ should not have increased.
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// When checksum is set, it should be received.
handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Invalid count should not have increased again.
if got := invalid.Value(); got != 1 {
@@ -869,14 +1130,14 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Run(typ.name, func(t *testing.T) {
e := channel.New(10, 1280, linkAddr0)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -886,21 +1147,21 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
s.SetRouteTable(
[]tcpip.Route{{
Destination: subnet,
- NIC: 1,
+ NIC: nicID,
}},
)
}
handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
- pkt := header.ICMPv6(hdr.Prepend(size))
- pkt.SetType(typ)
+ icmpHdr := header.ICMPv6(hdr.Prepend(size))
+ icmpHdr.SetType(typ)
payload := buffer.NewView(payloadSize)
payloadFn(payload)
if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView()))
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView()))
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -911,9 +1172,10 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
})
+ e.InjectInbound(ProtocolNumber, pkt)
}
stats := s.Stats().ICMP.V6PacketsReceived
@@ -925,7 +1187,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Without setting checksum, the incoming packet should
@@ -936,13 +1198,13 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
}
// Rx count of type typ.typ should not have increased.
if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// When checksum is set, it should be received.
handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ t.Fatalf("got = %d, want = 0", got)
}
// Invalid count should not have increased again.
if got := invalid.Value(); got != 1 {
@@ -951,3 +1213,50 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
})
}
}
+
+func TestLinkAddressRequest(t *testing.T) {
+ snaddr := header.SolicitedNodeAddr(lladdr0)
+ mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
+
+ tests := []struct {
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ remoteLinkAddr: linkAddr1,
+ expectLinkAddr: linkAddr1,
+ },
+ {
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectLinkAddr: mcaddr,
+ },
+ }
+
+ for _, test := range tests {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ p := s.NetworkProtocolInstance(ProtocolNumber)
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
+ if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err)
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 95fbcf2d1..c8a3e0b34 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,23 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ipv6 contains the implementation of the ipv6 network protocol. To use
-// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv6.NewProtocol() as one of the network
-// protocols when calling stack.New(). Then endpoints can be created by passing
-// ipv6.ProtocolNumber as the network protocol number when calling
-// Stack.NewEndpoint().
+// Package ipv6 contains the implementation of the ipv6 network protocol.
package ipv6
import (
"fmt"
+ "sort"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
- "gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -45,46 +42,313 @@ const (
DefaultTTL = 64
)
+var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+var _ stack.NDPEndpoint = (*endpoint)(nil)
+var _ NDPEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- nicID tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
+ nic stack.NetworkInterface
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
+ nud stack.NUDHandler
dispatcher stack.TransportDispatcher
- fragmentation *fragmentation.Fragmentation
protocol *protocol
+ stack *stack.Stack
+
+ // enabled is set to 1 when the endpoint is enabled and 0 when it is
+ // disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ mu struct {
+ sync.RWMutex
+
+ addressableEndpointState stack.AddressableEndpointState
+ ndp ndpState
+ }
}
-// DefaultTTL is the default hop limit for this endpoint.
-func (e *endpoint) DefaultTTL() uint8 {
- return e.protocol.DefaultTTL()
+// NICNameFromID is a function that returns a stable name for the specified NIC,
+// even if different NIC IDs are used to refer to the same NIC in different
+// program runs. It is used when generating opaque interface identifiers (IIDs).
+// If the NIC was created with a name, it is passed to NICNameFromID.
+//
+// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
+// generated for the same prefix on differnt NICs.
+type NICNameFromID func(tcpip.NICID, string) string
+
+// OpaqueInterfaceIdentifierOptions holds the options related to the generation
+// of opaque interface indentifiers (IIDs) as defined by RFC 7217.
+type OpaqueInterfaceIdentifierOptions struct {
+ // NICNameFromID is a function that returns a stable name for a specified NIC,
+ // even if the NIC ID changes over time.
+ //
+ // Must be specified to generate the opaque IID.
+ NICNameFromID NICNameFromID
+
+ // SecretKey is a pseudo-random number used as the secret key when generating
+ // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
+ // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
+ // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
+ // change between program runs, unless explicitly changed.
+ //
+ // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
+ // MUST NOT be modified after Stack is created.
+ //
+ // May be nil, but a nil value is highly discouraged to maintain
+ // some level of randomness between nodes.
+ SecretKey []byte
}
-// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
-// the network layer max header length.
-func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
+// InvalidateDefaultRouter implements stack.NDPEndpoint.
+func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.mu.ndp.invalidateDefaultRouter(rtr)
+}
+
+// SetNDPConfigurations implements NDPEndpoint.
+func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) {
+ c.validate()
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.mu.ndp.configs = c
+}
+
+// hasTentativeAddr returns true if addr is tentative on e.
+func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
+ e.mu.RLock()
+ addressEndpoint := e.getAddressRLocked(addr)
+ e.mu.RUnlock()
+ return addressEndpoint != nil && addressEndpoint.GetKind() == stack.PermanentTentative
+}
+
+// dupTentativeAddrDetected attempts to inform e that a tentative addr is a
+// duplicate on a link.
+//
+// dupTentativeAddrDetected removes the tentative address if it exists. If the
+// address was generated via SLAAC, an attempt is made to generate a new
+// address.
+func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil {
+ return tcpip.ErrBadAddress
+ }
+
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
+ // attempt will be made to generate a new address for it.
+ if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ return err
+ }
+
+ prefix := addressEndpoint.AddressWithPrefix().Subnet()
+
+ switch t := addressEndpoint.ConfigType(); t {
+ case stack.AddressConfigStatic:
+ case stack.AddressConfigSlaac:
+ e.mu.ndp.regenerateSLAACAddr(prefix)
+ case stack.AddressConfigSlaacTemp:
+ // Do not reset the generation attempts counter for the prefix as the
+ // temporary address is being regenerated in response to a DAD conflict.
+ e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
+ default:
+ panic(fmt.Sprintf("unrecognized address config type = %d", t))
+ }
+
+ return nil
+}
+
+// transitionForwarding transitions the endpoint's forwarding status to
+// forwarding.
+//
+// Must only be called when the forwarding status changes.
+func (e *endpoint) transitionForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if !e.Enabled() {
+ return
+ }
+
+ if forwarding {
+ // When transitioning into an IPv6 router, host-only state (NDP discovered
+ // routers, discovered on-link prefixes, and auto-generated addresses) is
+ // cleaned up/invalidated and NDP router solicitations are stopped.
+ e.mu.ndp.stopSolicitingRouters()
+ e.mu.ndp.cleanupState(true /* hostOnly */)
+ } else {
+ // When transitioning into an IPv6 host, NDP router solicitations are
+ // started.
+ e.mu.ndp.startSolicitingRouters()
+ }
+}
+
+// Enable implements stack.NetworkEndpoint.
+func (e *endpoint) Enable() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // If the NIC is not enabled, the endpoint can't do anything meaningful so
+ // don't enable the endpoint.
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ // If the endpoint is already enabled, there is nothing for it to do.
+ if !e.setEnabled(true) {
+ return nil
+ }
+
+ // Join the IPv6 All-Nodes Multicast group if the stack is configured to
+ // use IPv6. This is required to ensure that this node properly receives
+ // and responds to the various NDP messages that are destined to the
+ // all-nodes multicast address. An example is the Neighbor Advertisement
+ // when we perform Duplicate Address Detection, or Router Advertisement
+ // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
+ // section 4.2 for more information.
+ //
+ // Also auto-generate an IPv6 link-local address based on the endpoint's
+ // link address if it is configured to do so. Note, each interface is
+ // required to have IPv6 link-local unicast address, as per RFC 4291
+ // section 2.1.
+
+ // Join the All-Nodes multicast group before starting DAD as responses to DAD
+ // (NDP NS) messages may be sent to the All-Nodes multicast group if the
+ // source address of the NDP NS is the unspecified address, as per RFC 4861
+ // section 7.2.4.
+ if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
+ // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
+ // state.
+ //
+ // Addresses may have aleady completed DAD but in the time since the endpoint
+ // was last enabled, other devices may have acquired the same addresses.
+ var err *tcpip.Error
+ e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ addr := addressEndpoint.AddressWithPrefix().Address
+ if !header.IsV6UnicastAddress(addr) {
+ return true
+ }
+
+ switch addressEndpoint.GetKind() {
+ case stack.Permanent:
+ addressEndpoint.SetKind(stack.PermanentTentative)
+ fallthrough
+ case stack.PermanentTentative:
+ err = e.mu.ndp.startDuplicateAddressDetection(addr, addressEndpoint)
+ return err == nil
+ default:
+ return true
+ }
+ })
+ if err != nil {
+ return err
+ }
+
+ // Do not auto-generate an IPv6 link-local address for loopback devices.
+ if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() {
+ // The valid and preferred lifetime is infinite for the auto-generated
+ // link-local address.
+ e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
+ }
+
+ // If we are operating as a router, then do not solicit routers since we
+ // won't process the RAs anyway.
+ //
+ // Routers do not process Router Advertisements (RA) the same way a host
+ // does. That is, routers do not learn from RAs (e.g. on-link prefixes
+ // and default routers). Therefore, soliciting RAs from other routers on
+ // a link is unnecessary for routers.
+ if !e.protocol.Forwarding() {
+ e.mu.ndp.startSolicitingRouters()
+ }
+
+ return nil
+}
+
+// Enabled implements stack.NetworkEndpoint.
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
+}
+
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the endpoint.
+//
+// Returns true if the enabled status was updated.
+func (e *endpoint) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&e.enabled, 1) == 0
+ }
+ return atomic.SwapUint32(&e.enabled, 0) == 1
}
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
+// Disable implements stack.NetworkEndpoint.
+func (e *endpoint) Disable() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.disableLocked()
}
-// ID returns the ipv6 endpoint ID.
-func (e *endpoint) ID() *stack.NetworkEndpointID {
- return &e.id
+func (e *endpoint) disableLocked() {
+ if !e.setEnabled(false) {
+ return
+ }
+
+ e.mu.ndp.stopSolicitingRouters()
+ e.mu.ndp.cleanupState(false /* hostOnly */)
+ e.stopDADForPermanentAddressesLocked()
+
+ // The endpoint may have already left the multicast group.
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
+ }
}
-// PrefixLen returns the ipv6 endpoint subnet prefix length in bits.
-func (e *endpoint) PrefixLen() int {
- return e.prefixLen
+// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) stopDADForPermanentAddressesLocked() {
+ // Stop DAD for all the tentative unicast addresses.
+ e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ return true
+ }
+
+ addr := addressEndpoint.AddressWithPrefix().Address
+ if header.IsV6UnicastAddress(addr) {
+ e.mu.ndp.stopDuplicateAddressDetection(addr)
+ }
+
+ return true
+ })
}
-// Capabilities implements stack.NetworkEndpoint.Capabilities.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
+// DefaultTTL is the default hop limit for this endpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+ return e.protocol.DefaultTTL()
+}
+
+// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
+// the network layer max header length.
+func (e *endpoint) MTU() uint32 {
+ return calculateMTU(e.linkEP.MTU())
}
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
@@ -101,9 +365,9 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
-func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 {
- length := uint16(hdr.UsedLength() + payloadSize)
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+ length := uint16(pkt.Size())
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(params.Protocol),
@@ -112,25 +376,46 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- return ip
+ pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
}
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
- pkt.NetworkHeader = buffer.View(ip)
+ e.addIPHeader(r, pkt, params)
+
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
+ if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ // iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesOutputDropped.Increment()
+ return nil
+ }
+
+ // If the packet is manipulated as per NAT Output rules, handle packet
+ // based on destination address and do not send the packet to link
+ // layer.
+ //
+ // TODO(gvisor.dev/issue/170): We should do this for every
+ // packet, rather than only NATted packets, but removing this check
+ // short circuits broadcasts before they are sent out to other hosts.
+ if pkt.NatDone {
+ netHeader := header.IPv6(pkt.NetworkHeader().View())
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ ep.HandlePacket(&route, pkt)
+ return nil
+ }
+ }
if r.Loop&stack.PacketLoop != 0 {
- // The inbound path expects the network header to still be in
- // the PacketBuffer's Data field.
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, &stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
- })
+ e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ // The inbound path expects an unparsed packet.
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ }))
loopedR.Release()
}
@@ -138,8 +423,12 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt)
+ return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -152,13 +441,57 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params)
- pb.NetworkHeader = buffer.View(ip)
+ e.addIPHeader(r, pb, params)
+ }
+
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
+ dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ if len(dropped) == 0 && len(natPkts) == 0 {
+ // Fast path: If no packets are to be dropped then we can just invoke the
+ // faster WritePackets API directly.
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
+ return n, err
+ }
+ r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+
+ // Slow path as we are dropping some packets in the batch degrade to
+ // emitting one packet at a time.
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if _, ok := dropped[pkt]; ok {
+ continue
+ }
+ if _, ok := natPkts[pkt]; ok {
+ netHeader := header.IPv6(pkt.NetworkHeader().View())
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ src := netHeader.SourceAddress()
+ dst := netHeader.DestinationAddress()
+ route := r.ReverseRoute(src, dst)
+ ep.HandlePacket(&route, pkt)
+ n++
+ continue
+ }
+ }
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
+ // Dropped packets aren't errors, so include them in
+ // the return value.
+ return n + len(dropped), err
+ }
+ n++
}
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ // Dropped packets aren't errors, so include them in the return value.
+ return n + len(dropped), nil
}
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
@@ -171,23 +504,47 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
- h := header.IPv6(pkt.NetworkHeader)
- if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) {
+ if !e.isEnabled() {
+ return
+ }
+
+ h := header.IPv6(pkt.NetworkHeader().View())
+ if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
+ // As per RFC 4291 section 2.7:
+ // Multicast addresses must not be used as source addresses in IPv6
+ // packets or appear in any Routing header.
+ if header.IsV6MulticastAddress(r.RemoteAddress) {
+ r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
// - Any other payload data.
- vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView()
- vv.AppendView(pkt.TransportHeader)
+ vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView()
+ vv.AppendView(pkt.TransportHeader().View())
vv.Append(pkt.Data)
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
hasFragmentHeader := false
- for firstHeader := true; ; firstHeader = false {
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and need not be forwarded.
+ ipt := e.protocol.stack.IPTables()
+ if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ // iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesInputDropped.Increment()
+ return
+ }
+
+ for {
+ // Keep track of the start of the previous header so we can report the
+ // special case of a Hop by Hop at a location other than at the start.
+ previousHeaderStart := it.HeaderOffset()
extHdr, done, err := it.Next()
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -201,11 +558,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6HopByHopOptionsExtHdr:
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1
- // (unrecognized next header) error in response to an extension header's
- // Next Header field with the Hop By Hop extension header identifier.
- if !firstHeader {
+ if previousHeaderStart != 0 {
+ _ = returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: previousHeaderStart,
+ }, pkt)
return
}
@@ -227,13 +584,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
@@ -244,16 +613,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// As per RFC 8200 section 4.4, if a node encounters a routing header with
// an unrecognized routing type value, with a non-zero Segments Left
// value, the node must discard the packet and send an ICMP Parameter
- // Problem, Code 0. If the Segments Left is 0, the node must ignore the
- // Routing extension header and process the next header in the packet.
+ // Problem, Code 0 to the packet's Source Address, pointing to the
+ // unrecognized Routing Type.
+ //
+ // If the Segments Left is 0, the node must ignore the Routing extension
+ // header and process the next header in the packet.
//
// Note, the stack does not yet handle any type of routing extension
// header, so we just make sure Segments Left is zero before processing
// the next extension header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for
- // unrecognized routing types with a non-zero Segments Left value.
if extHdr.SegmentsLeft() != 0 {
+ _ = returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
return
}
@@ -286,7 +659,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
it, done, err := it.Next()
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
if done {
@@ -329,32 +702,44 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// The packet is a fragment, let's try to reassemble it.
start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
- last := start + uint16(fragmentPayloadLen) - 1
- // Drop the packet if the fragmentOffset is incorrect. i.e the
- // combination of fragmentOffset and pkt.Data.size() causes a
- // wrap around resulting in last being less than the offset.
- if last < start {
+ // Drop the fragment if the size of the reassembled payload would exceed
+ // the maximum payload size.
+ if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
- var ready bool
// Note that pkt doesn't have its transport header set after reassembly,
// and won't until DeliverNetworkPacket sets it.
- pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, extHdr.More(), rawPayload.Buf)
+ data, proto, ready, err := e.protocol.fragmentation.Process(
+ // IPv6 ignores the Protocol field since the ID only needs to be unique
+ // across source-destination pairs, as per RFC 8200 section 4.5.
+ fragmentation.FragmentID{
+ Source: h.SourceAddress(),
+ Destination: h.DestinationAddress(),
+ ID: extHdr.ID(),
+ },
+ start,
+ start+uint16(fragmentPayloadLen)-1,
+ extHdr.More(),
+ uint8(rawPayload.Identifier),
+ rawPayload.Buf,
+ )
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
+ pkt.Data = data
if ready {
// We create a new iterator with the reassembled packet because we could
// have more extension headers in the reassembled payload, as per RFC
- // 8200 section 4.5.
- it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data)
+ // 8200 section 4.5. We also use the NextHeader value from the first
+ // fragment.
+ it = header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(proto), pkt.Data)
}
case header.IPv6DestinationOptionsExtHdr:
@@ -376,13 +761,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
@@ -398,24 +795,58 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
//
// For reassembled fragments, pkt.TransportHeader is unset, so this is a
// no-op and pkt.Data begins with the transport header.
- extHdr.Buf.TrimFront(len(pkt.TransportHeader))
+ extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
pkt.Data = extHdr.Buf
+ r.Stats().IP.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
+ pkt.TransportProtocolNumber = p
e.handleICMP(r, pkt, hasFragmentHeader)
} else {
r.Stats().IP.PacketsDelivered.Increment()
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
- e.dispatcher.DeliverTransportPacket(r, p, pkt)
+ switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ case stack.TransportPacketHandled:
+ case stack.TransportPacketDestinationPortUnreachable:
+ // As per RFC 4443 section 3.1:
+ // A destination node SHOULD originate a Destination Unreachable
+ // message with Code 4 in response to a packet for which the
+ // transport protocol (e.g., UDP) has no listener, if that transport
+ // protocol has no alternative means to inform the sender.
+ _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC 8200 section 4. (page 7):
+ // Extension headers are numbered from IANA IP Protocol Numbers
+ // [IANA-PN], the same values used for IPv4 and IPv6. When
+ // processing a sequence of Next Header values in a packet, the
+ // first one that is not an extension header [IANA-EH] indicates
+ // that the next item in the packet is the corresponding upper-layer
+ // header.
+ // With more related information on page 8:
+ // If, as a result of processing a header, the destination node is
+ // required to proceed to the next header but the Next Header value
+ // in the current header is unrecognized by the node, it should
+ // discard the packet and send an ICMP Parameter Problem message to
+ // the source of the packet, with an ICMP Code value of 1
+ // ("unrecognized Next Header type encountered") and the ICMP
+ // Pointer field containing the offset of the unrecognized value
+ // within the original packet.
+ //
+ // Which when taken together indicate that an unknown protocol should
+ // be treated as an unrecognized next header value.
+ _ = returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
+ default:
+ panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
+ }
}
default:
- // If we receive a packet for an extension header we do not yet handle,
- // drop the packet for now.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
+ _ = returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
r.Stats().UnknownProtocolRcvdPackets.Increment()
return
}
@@ -423,18 +854,340 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
// Close cleans up resources associated with the endpoint.
-func (*endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.disableLocked()
+ e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */)
+ e.stopDADForPermanentAddressesLocked()
+ e.mu.addressableEndpointState.Cleanup()
+ e.mu.Unlock()
+
+ e.protocol.forgetEndpoint(e)
+}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
+// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ // TODO(b/169350103): add checks here after making sure we no longer receive
+ // an empty address.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
+}
+
+// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but
+// with locking requirements.
+//
+// addAndAcquirePermanentAddressLocked also joins the passed address's
+// solicited-node multicast group and start duplicate address detection.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ if err != nil {
+ return nil, err
+ }
+
+ if !header.IsV6UnicastAddress(addr.Address) {
+ return addressEndpoint, nil
+ }
+
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil {
+ return nil, err
+ }
+
+ addressEndpoint.SetKind(stack.PermanentTentative)
+
+ if e.Enabled() {
+ if err := e.mu.ndp.startDuplicateAddressDetection(addr.Address, addressEndpoint); err != nil {
+ return nil, err
+ }
+ }
+
+ return addressEndpoint, nil
+}
+
+// RemovePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ return e.removePermanentEndpointLocked(addressEndpoint, true)
+}
+
+// removePermanentEndpointLocked is like removePermanentAddressLocked except
+// it works with a stack.AddressEndpoint.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
+ addr := addressEndpoint.AddressWithPrefix()
+ unicast := header.IsV6UnicastAddress(addr.Address)
+ if unicast {
+ e.mu.ndp.stopDuplicateAddressDetection(addr.Address)
+
+ // If we are removing an address generated via SLAAC, cleanup
+ // its SLAAC resources and notify the integrator.
+ switch addressEndpoint.ConfigType() {
+ case stack.AddressConfigSlaac:
+ e.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ case stack.AddressConfigSlaacTemp:
+ e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ }
+ }
+
+ if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil {
+ return err
+ }
+
+ if !unicast {
+ return nil
+ }
+
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+
+ return nil
+}
+
+// hasPermanentAddressLocked returns true if the endpoint has a permanent
+// address equal to the passed address.
+//
+// Precondition: e.mu must be read or write locked.
+func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool {
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil {
+ return false
+ }
+ return addressEndpoint.GetKind().IsPermanent()
+}
+
+// getAddressRLocked returns the endpoint for the passed address.
+//
+// Precondition: e.mu must be read or write locked.
+func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint {
+ return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr)
+}
+
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
+// AcquireAssignedAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB)
+}
+
+// acquireAddressOrCreateTempLocked is like AcquireAssignedAddress but with
+// locking requirements.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) acquireAddressOrCreateTempLocked(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ return e.mu.addressableEndpointState.AcquireAssignedAddress(localAddr, allowTemp, tempPEB)
+}
+
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
+}
+
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements.
+//
+// Precondition: e.mu must be read locked.
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ // addrCandidate is a candidate for Source Address Selection, as per
+ // RFC 6724 section 5.
+ type addrCandidate struct {
+ addressEndpoint stack.AddressEndpoint
+ scope header.IPv6AddressScope
+ }
+
+ if len(remoteAddr) == 0 {
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
+ }
+
+ // Create a candidate set of available addresses we can potentially use as a
+ // source address.
+ var cs []addrCandidate
+ e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
+ // If r is not valid for outgoing connections, it is not a valid endpoint.
+ if !addressEndpoint.IsAssigned(allowExpired) {
+ return
+ }
+
+ addr := addressEndpoint.AddressWithPrefix().Address
+ scope, err := header.ScopeForIPv6Address(addr)
+ if err != nil {
+ // Should never happen as we got r from the primary IPv6 endpoint list and
+ // ScopeForIPv6Address only returns an error if addr is not an IPv6
+ // address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err))
+ }
+
+ cs = append(cs, addrCandidate{
+ addressEndpoint: addressEndpoint,
+ scope: scope,
+ })
+ })
+
+ remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
+ if err != nil {
+ // primaryIPv6Endpoint should never be called with an invalid IPv6 address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
+ }
+
+ // Sort the addresses as per RFC 6724 section 5 rules 1-3.
+ //
+ // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ sort.Slice(cs, func(i, j int) bool {
+ sa := cs[i]
+ sb := cs[j]
+
+ // Prefer same address as per RFC 6724 section 5 rule 1.
+ if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ return true
+ }
+ if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ return false
+ }
+
+ // Prefer appropriate scope as per RFC 6724 section 5 rule 2.
+ if sa.scope < sb.scope {
+ return sa.scope >= remoteScope
+ } else if sb.scope < sa.scope {
+ return sb.scope < remoteScope
+ }
+
+ // Avoid deprecated addresses as per RFC 6724 section 5 rule 3.
+ if saDep, sbDep := sa.addressEndpoint.Deprecated(), sb.addressEndpoint.Deprecated(); saDep != sbDep {
+ // If sa is not deprecated, it is preferred over sb.
+ return sbDep
+ }
+
+ // Prefer temporary addresses as per RFC 6724 section 5 rule 7.
+ if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp {
+ return saTemp
+ }
+
+ // sa and sb are equal, return the endpoint that is closest to the front of
+ // the primary endpoint list.
+ return i < j
+ })
+
+ // Return the most preferred address that can have its reference count
+ // incremented.
+ for _, c := range cs {
+ if c.addressEndpoint.IncRef() {
+ return c.addressEndpoint
+ }
+ }
+
+ return nil
+}
+
+// PrimaryAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PrimaryAddresses()
+}
+
+// PermanentAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PermanentAddresses()
+}
+
+// JoinGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ if !header.IsV6MulticastAddress(addr) {
+ return false, tcpip.ErrBadAddress
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.JoinGroup(addr)
+}
+
+// LeaveGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.LeaveGroup(addr)
+}
+
+// IsInGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.IsInGroup(addr)
+}
+
+var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
+var _ stack.NetworkProtocol = (*protocol)(nil)
+
type protocol struct {
+ stack *stack.Stack
+
+ mu struct {
+ sync.RWMutex
+
+ eps map[*endpoint]struct{}
+ }
+
// defaultTTL is the current default TTL for the protocol. Only the
- // uint8 portion of it is meaningful and it must be accessed
- // atomically.
+ // uint8 portion of it is meaningful.
+ //
+ // Must be accessed using atomic operations.
defaultTTL uint32
+
+ // forwarding is set to 1 when the protocol has forwarding enabled and 0
+ // when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
+ fragmentation *fragmentation.Fragmentation
+
+ // ndpDisp is the NDP event dispatcher that is used to send the netstack
+ // integrator NDP related events.
+ ndpDisp NDPDispatcher
+
+ // ndpConfigs is the default NDP configurations used by an IPv6 endpoint.
+ ndpConfigs NDPConfigurations
+
+ // opaqueIIDOpts hold the options for generating opaque interface identifiers
+ // (IIDs) as outlined by RFC 7217.
+ opaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // tempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ tempIIDSeed []byte
+
+ // autoGenIPv6LinkLocal determines whether or not the stack attempts to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
+ autoGenIPv6LinkLocal bool
}
// Number returns the ipv6 protocol number.
@@ -459,24 +1212,43 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
- return &endpoint{
- nicID: nicID,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
- linkEP: linkEP,
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
+ nic: nic,
+ linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
+ nud: nud,
dispatcher: dispatcher,
- fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
protocol: p,
- }, nil
+ }
+ e.mu.addressableEndpointState.Init(e)
+ e.mu.ndp = ndpState{
+ ep: e,
+ configs: p.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
+ }
+ e.mu.ndp.initializeTempAddrState()
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.mu.eps[e] = struct{}{}
+ return e
+}
+
+func (p *protocol) forgetEndpoint(e *endpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ delete(p.mu.eps, e)
}
// SetOption implements NetworkProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case tcpip.DefaultTTLOption:
- p.SetDefaultTTL(uint8(v))
+ case *tcpip.DefaultTTLOption:
+ p.SetDefaultTTL(uint8(*v))
return nil
default:
return tcpip.ErrUnknownProtocolOption
@@ -484,7 +1256,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
// Option implements NetworkProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(p.DefaultTTL())
@@ -510,77 +1282,43 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// Parse implements stack.TransportProtocol.Parse.
+// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
- hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt)
if !ok {
return 0, false, false
}
- ipHdr := header.IPv6(hdr)
- // dataClone consists of:
- // - Any IPv6 header bytes after the first 40 (i.e. extensions).
- // - The transport header, if present.
- // - Any other payload data.
- views := [8]buffer.View{}
- dataClone := pkt.Data.Clone(views[:])
- dataClone.TrimFront(header.IPv6MinimumSize)
- it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
+ return proto, !fragMore && fragOffset == 0, true
+}
- // Iterate over the IPv6 extensions to find their length.
- //
- // Parsing occurs again in HandlePacket because we don't track the
- // extensions in PacketBuffer. Unfortunately, that means HandlePacket
- // has to do the parsing work again.
- var nextHdr tcpip.TransportProtocolNumber
- foundNext := true
- extensionsSize := 0
-traverseExtensions:
- for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() {
- if err != nil {
- break
- }
- // If we exhaust the extension list, the entire packet is the IPv6 header
- // and (possibly) extensions.
- if done {
- extensionsSize = dataClone.Size()
- foundNext = false
- break
- }
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) Forwarding() bool {
+ return uint8(atomic.LoadUint32(&p.forwarding)) == 1
+}
- switch extHdr := extHdr.(type) {
- case header.IPv6FragmentExtHdr:
- // If this is an atomic fragment, we don't have to treat it specially.
- if !extHdr.More() && extHdr.FragmentOffset() == 0 {
- continue
- }
- // This is a non-atomic fragment and has to be re-assembled before we can
- // examine the payload for a transport header.
- foundNext = false
+// setForwarding sets the forwarding status for the protocol.
+//
+// Returns true if the forwarding status was updated.
+func (p *protocol) setForwarding(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&p.forwarding, 1) == 0
+ }
+ return atomic.SwapUint32(&p.forwarding, 0) == 1
+}
- case header.IPv6RawPayloadHeader:
- // We've found the payload after any extensions.
- extensionsSize = dataClone.Size() - extHdr.Buf.Size()
- nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
- break traverseExtensions
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) SetForwarding(v bool) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
- default:
- // Any other extension is a no-op, keep looping until we find the payload.
- }
+ if !p.setForwarding(v) {
+ return
}
- // Put the IPv6 header with extensions in pkt.NetworkHeader.
- hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize)
- if !ok {
- panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
+ for ep := range p.mu.eps {
+ ep.transitionForwarding(v)
}
- ipHdr = header.IPv6(hdr)
-
- pkt.NetworkHeader = hdr
- pkt.Data.TrimFront(len(hdr))
- pkt.Data.CapLength(int(ipHdr.PayloadLength()))
-
- return nextHdr, foundNext, true
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
@@ -593,7 +1331,69 @@ func calculateMTU(mtu uint32) uint32 {
return maxPayloadSize
}
-// NewProtocol returns an IPv6 network protocol.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{defaultTTL: DefaultTTL}
+// Options holds options to configure a new protocol.
+type Options struct {
+ // NDPConfigs is the default NDP configurations used by interfaces.
+ NDPConfigs NDPConfigurations
+
+ // AutoGenIPv6LinkLocal determines whether or not the stack attempts to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs.
+ //
+ // Note, setting this to true does not mean that a link-local address is
+ // assigned right away, or at all. If Duplicate Address Detection is enabled,
+ // an address is only assigned if it successfully resolves. If it fails, no
+ // further attempts are made to auto-generate an IPv6 link-local adddress.
+ //
+ // The generated link-local address follows RFC 4291 Appendix A guidelines.
+ AutoGenIPv6LinkLocal bool
+
+ // NDPDisp is the NDP event dispatcher that an integrator can provide to
+ // receive NDP related events.
+ NDPDisp NDPDispatcher
+
+ // OpaqueIIDOpts hold the options for generating opaque interface
+ // identifiers (IIDs) as outlined by RFC 7217.
+ OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // TempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ //
+ // Temporary SLAAC adresses are short-lived addresses which are unpredictable
+ // and random from the perspective of other nodes on the network. It is
+ // recommended that the seed be a random byte buffer of at least
+ // header.IIDSize bytes to make sure that temporary SLAAC addresses are
+ // sufficiently random. It should follow minimum randomness requirements for
+ // security as outlined by RFC 4086.
+ //
+ // Note: using a nil value, the same seed across netstack program runs, or a
+ // seed that is too small would reduce randomness and increase predictability,
+ // defeating the purpose of temporary SLAAC addresses.
+ TempIIDSeed []byte
+}
+
+// NewProtocolWithOptions returns an IPv6 network protocol.
+func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
+ opts.NDPConfigs.validate()
+
+ return func(s *stack.Stack) stack.NetworkProtocol {
+ p := &protocol{
+ stack: s,
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()),
+
+ ndpDisp: opts.NDPDisp,
+ ndpConfigs: opts.NDPConfigs,
+ opaqueIIDOpts: opts.OpaqueIIDOpts,
+ tempIIDSeed: opts.TempIIDSeed,
+ autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ }
+ p.mu.eps = make(map[*endpoint]struct{})
+ p.SetDefaultTTL(DefaultTTL)
+ return p
+ }
+}
+
+// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options.
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return NewProtocolWithOptions(Options{})(s)
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 213ff64f2..d7f82973b 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,13 +15,16 @@
package ipv6
import (
+ "math"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -65,9 +68,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
stats := s.Stats().ICMP.V6PacketsReceived
@@ -123,9 +126,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
stat := s.Stats().UDP.PacketsReceived
@@ -139,18 +142,18 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
tests := []struct {
name string
- protocolFactory stack.TransportProtocol
+ protocolFactory stack.TransportProtocolFactory
rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
}{
- {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
- {"UDP", udp.NewProtocol(), testReceiveUDP},
+ {"ICMP", icmp.NewProtocol6, testReceiveICMP},
+ {"UDP", udp.NewProtocol, testReceiveUDP},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
})
e := channel.New(10, 1280, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
@@ -172,11 +175,11 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
tests := []struct {
name string
- protocolFactory stack.TransportProtocol
+ protocolFactory stack.TransportProtocolFactory
rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
}{
- {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
- {"UDP", udp.NewProtocol(), testReceiveUDP},
+ {"ICMP", icmp.NewProtocol6, testReceiveICMP},
+ {"UDP", udp.NewProtocol, testReceiveUDP},
}
snmc := header.SolicitedNodeAddr(addr2)
@@ -184,8 +187,8 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
})
e := channel.New(1, 1280, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
@@ -271,7 +274,7 @@ func TestAddIpv6Address(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -299,11 +302,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
name string
extHdr func(nextHdr uint8) ([]byte, uint8)
shouldAccept bool
+ // Should we expect an ICMP response and if so, with what contents?
+ expectICMP bool
+ ICMPType header.ICMPv6Type
+ ICMPCode header.ICMPv6Code
+ pointer uint32
+ multicast bool
}{
{
name: "None",
extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
shouldAccept: true,
+ expectICMP: false,
},
{
name: "hopbyhop with unknown option skippable action",
@@ -334,9 +344,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action (unicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "hopbyhop with unknown option discard and send icmp action",
+ name: "hopbyhop with unknown option discard and send icmp action (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -346,12 +377,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
+ multicast: true,
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "hopbyhop with unknown option discard and send icmp action unless multicast dest",
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -362,39 +399,97 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ multicast: true,
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "routing with zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: true,
},
{
- name: "routing with non-zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with non-zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 1, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6ErroneousHeader,
+ pointer: header.IPv6FixedHeaderSize + 2,
},
{
- name: "atomic fragment with zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID },
+ name: "atomic fragment with zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 0, 0, 0, 0,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
},
{
- name: "atomic fragment with non-zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "atomic fragment with non-zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
+ expectICMP: false,
},
{
- name: "fragment",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{},
+ noNextHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
name: "destination with unknown option skippable action",
@@ -410,6 +505,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: true,
+ expectICMP: false,
},
{
name: "destination with unknown option discard action",
@@ -425,9 +521,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action (unicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "destination with unknown option discard and send icmp action",
+ name: "destination with unknown option discard and send icmp action (muilticast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -437,12 +554,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
}, destinationExtHdrID
},
+ multicast: true,
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "destination with unknown option discard and send icmp action unless multicast dest",
+ name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -453,22 +576,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "routing - atomic fragment",
+ name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ nextHdr, 1,
- // Fragment extension header.
- nextHdr, 0, 0, 0, 1, 2, 3, 4,
- }, routingExtHdrID
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
+ }, destinationExtHdrID
},
- shouldAccept: true,
+ shouldAccept: false,
+ expectICMP: false,
+ multicast: true,
},
{
name: "atomic fragment - routing",
@@ -502,12 +636,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
return []byte{
// Routing extension header.
hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The HopByHop extension header may not appear after the first
+ // extension header.
// Hop By Hop extension header with skippable unknown option.
nextHdr, 0, 62, 4, 1, 2, 3, 4,
}, routingExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
+ },
+ {
+ name: "routing - hop by hop (with send icmp unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The HopByHop extension header may not appear after the first
+ // extension header.
+
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, routingExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
},
{
name: "No next header",
@@ -551,6 +715,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
@@ -571,16 +736,17 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(0, 1280, linkAddr1)
+ e := channel.New(1, 1280, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -588,6 +754,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
+ // Add a default route so that a return packet knows where to go.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
@@ -629,17 +803,21 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Serialize IPv6 fixed header.
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ dstAddr := tcpip.Address(addr2)
+ if test.multicast {
+ dstAddr = header.IPv6AllNodesMulticastAddress
+ }
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: ipv6NextHdr,
HopLimit: 255,
SrcAddr: addr1,
- DstAddr: addr2,
+ DstAddr: dstAddr,
})
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
stats := s.Stats().UDP.PacketsReceived
@@ -648,6 +826,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = 0", got)
}
+ if !test.expectICMP {
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpected packet received: %#v", p)
+ }
+ return
+ }
+
+ // ICMP required.
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected packet wasn't written out")
+ }
+
+ // Pack the output packet into a single buffer.View as the checkers
+ // assume that.
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
+ if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want {
+ t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want)
+ }
+
+ ipHdr := header.IPv6(pkt)
+ checker.IPv6(t, ipHdr, checker.ICMPv6(
+ checker.ICMPv6Type(test.ICMPType),
+ checker.ICMPv6Code(test.ICMPCode)))
+
+ // We know we are looking at no extension headers in the error ICMP
+ // packets.
+ icmpPkt := header.ICMPv6(ipHdr.Payload())
+ // We know we sent small packets that won't be truncated when reflected
+ // back to us.
+ originalPacket := icmpPkt.Payload()
+ if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want {
+ t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want)
+ }
+ if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" {
+ t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff)
+ }
return
}
@@ -673,20 +889,28 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// fragmentData holds the IPv6 payload for a fragmented IPv6 packet.
type fragmentData struct {
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
nextHdr uint8
data buffer.VectorisedView
}
func TestReceiveIPv6Fragments(t *testing.T) {
- const nicID = 1
- const udpPayload1Length = 256
- const udpPayload2Length = 128
- const fragmentExtHdrLen = 8
- // Note, not all routing extension headers will be 8 bytes but this test
- // uses 8 byte routing extension headers for most sub tests.
- const routingExtHdrLen = 8
-
- udpGen := func(payload []byte, multiplier uint8) buffer.View {
+ const (
+ nicID = 1
+ udpPayload1Length = 256
+ udpPayload2Length = 128
+ // Used to test cases where the fragment blocks are not a multiple of
+ // the fragment block size of 8 (RFC 8200 section 4.5).
+ udpPayload3Length = 127
+ udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
+ fragmentExtHdrLen = 8
+ // Note, not all routing extension headers will be 8 bytes but this test
+ // uses 8 byte routing extension headers for most sub tests.
+ routingExtHdrLen = 8
+ )
+
+ udpGen := func(payload []byte, multiplier uint8, src, dst tcpip.Address) buffer.View {
payloadLen := len(payload)
for i := 0; i < payloadLen; i++ {
payload[i] = uint8(i) * multiplier
@@ -702,19 +926,31 @@ func TestReceiveIPv6Fragments(t *testing.T) {
Length: uint16(udpLength),
})
copy(u.Payload(), payload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength))
sum = header.Checksum(payload, sum)
u.SetChecksum(^u.CalculateChecksum(sum))
return hdr.View()
}
- var udpPayload1Buf [udpPayload1Length]byte
- udpPayload1 := udpPayload1Buf[:]
- ipv6Payload1 := udpGen(udpPayload1, 1)
+ var udpPayload1Addr1ToAddr2Buf [udpPayload1Length]byte
+ udpPayload1Addr1ToAddr2 := udpPayload1Addr1ToAddr2Buf[:]
+ ipv6Payload1Addr1ToAddr2 := udpGen(udpPayload1Addr1ToAddr2, 1, addr1, addr2)
+
+ var udpPayload1Addr3ToAddr2Buf [udpPayload1Length]byte
+ udpPayload1Addr3ToAddr2 := udpPayload1Addr3ToAddr2Buf[:]
+ ipv6Payload1Addr3ToAddr2 := udpGen(udpPayload1Addr3ToAddr2, 4, addr3, addr2)
+
+ var udpPayload2Addr1ToAddr2Buf [udpPayload2Length]byte
+ udpPayload2Addr1ToAddr2 := udpPayload2Addr1ToAddr2Buf[:]
+ ipv6Payload2Addr1ToAddr2 := udpGen(udpPayload2Addr1ToAddr2, 2, addr1, addr2)
- var udpPayload2Buf [udpPayload2Length]byte
- udpPayload2 := udpPayload2Buf[:]
- ipv6Payload2 := udpGen(udpPayload2, 2)
+ var udpPayload3Addr1ToAddr2Buf [udpPayload3Length]byte
+ udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:]
+ ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2)
+
+ var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte
+ udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:]
+ ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2)
tests := []struct {
name string
@@ -726,34 +962,60 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "No fragmentation",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: uint8(header.UDPProtocolNumber),
- data: ipv6Payload1.ToVectorisedView(),
+ data: ipv6Payload1Addr1ToAddr2.ToVectorisedView(),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
},
{
name: "Atomic fragment",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2),
+ []buffer.View{
+ // Fragment extension header.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+
+ ipv6Payload1Addr1ToAddr2,
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Atomic fragment with size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1),
+ fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
- ipv6Payload1,
+ ipv6Payload3Addr1ToAddr2,
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
},
{
name: "Two fragments",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+64,
@@ -763,31 +1025,189 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments out of order",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with different Next Header values",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ // NextHeader value is different than the one in the first fragment, so
+ // this NextHeader should be ignored.
+ buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with last fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with first fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+63,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[:63],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-63,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[63:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
},
{
name: "Two fragments with different IDs",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+64,
@@ -797,21 +1217,23 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 2
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
@@ -819,9 +1241,49 @@ func TestReceiveIPv6Fragments(t *testing.T) {
expectedPayloads: nil,
},
{
+ name: "Two fragments reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+65520,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[:65520],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8190, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[65520:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
+ },
+ {
name: "Two fragments with per-fragment routing header with zero segments left",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: routingExtHdrID,
data: buffer.NewVectorisedView(
routingExtHdrLen+fragmentExtHdrLen+64,
@@ -836,14 +1298,16 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: routingExtHdrID,
data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Routing extension header.
//
@@ -855,17 +1319,19 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 8, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
},
{
name: "Two fragments with per-fragment routing header with non-zero segments left",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: routingExtHdrID,
data: buffer.NewVectorisedView(
routingExtHdrLen+fragmentExtHdrLen+64,
@@ -880,14 +1346,16 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: routingExtHdrID,
data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64,
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Routing extension header.
//
@@ -899,7 +1367,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 9, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
@@ -910,6 +1378,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with routing header with zero segments left",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
routingExtHdrLen+fragmentExtHdrLen+64,
@@ -924,31 +1394,35 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Segments left = 0.
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 9, More = false, ID = 1
buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
},
{
name: "Two fragments with routing header with non-zero segments left",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
routingExtHdrLen+fragmentExtHdrLen+64,
@@ -963,21 +1437,23 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Segments left = 1.
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 9, More = false, ID = 1
buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
@@ -988,6 +1464,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with routing header with zero segments left across fragments",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
// The length of this payload is fragmentExtHdrLen+8 because the
@@ -1008,12 +1486,14 @@ func TestReceiveIPv6Fragments(t *testing.T) {
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
// The length of this payload is
- // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of
// the 16 byte routing extension header is in this fagment.
- fragmentExtHdrLen+8+len(ipv6Payload1),
+ fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
//
@@ -1023,7 +1503,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header (part 2)
buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
- ipv6Payload1,
+ ipv6Payload1Addr1ToAddr2,
},
),
},
@@ -1034,6 +1514,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with routing header with non-zero segments left across fragments",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
// The length of this payload is fragmentExtHdrLen+8 because the
@@ -1054,12 +1536,14 @@ func TestReceiveIPv6Fragments(t *testing.T) {
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
// The length of this payload is
- // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of
+ // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of
// the 16 byte routing extension header is in this fagment.
- fragmentExtHdrLen+8+len(ipv6Payload1),
+ fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
//
@@ -1069,7 +1553,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Routing extension header (part 2)
buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
- ipv6Payload1,
+ ipv6Payload1Addr1ToAddr2,
},
),
},
@@ -1082,6 +1566,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
name: "Two fragments with atomic",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+64,
@@ -1091,47 +1577,53 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
// This fragment has the same ID as the other fragments but is an atomic
// fragment. It should not interfere with the other fragments.
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload2),
+ fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2),
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 0, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}),
- ipv6Payload2,
+ ipv6Payload2Addr1ToAddr2,
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload2, udpPayload1},
+ expectedPayloads: [][]byte{udpPayload2Addr1ToAddr2, udpPayload1Addr1ToAddr2},
},
{
name: "Two interleaved fragmented packets",
fragments: []fragmentData{
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+64,
@@ -1141,11 +1633,13 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
- ipv6Payload1[:64],
+ ipv6Payload1Addr1ToAddr2[:64],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
fragmentExtHdrLen+32,
@@ -1155,48 +1649,122 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Fragment offset = 0, More = true, ID = 2
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}),
- ipv6Payload2[:32],
+ ipv6Payload2Addr1ToAddr2[:32],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1)-64,
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 8, More = false, ID = 1
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
- ipv6Payload1[64:],
+ ipv6Payload1Addr1ToAddr2[64:],
},
),
},
{
+ srcAddr: addr1,
+ dstAddr: addr2,
nextHdr: fragmentExtHdrID,
data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload2)-32,
+ fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2)-32,
[]buffer.View{
// Fragment extension header.
//
// Fragment offset = 4, More = false, ID = 2
buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}),
- ipv6Payload2[32:],
+ ipv6Payload2Addr1ToAddr2[32:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2},
+ },
+ {
+ name: "Two interleaved fragmented packets from different sources but with same ID",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr3ToAddr2[:32],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 4, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr3ToAddr2[32:],
},
),
},
},
- expectedPayloads: [][]byte{udpPayload1, udpPayload2},
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
e := channel.New(0, 1280, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
@@ -1231,16 +1799,16 @@ func TestReceiveIPv6Fragments(t *testing.T) {
PayloadLength: uint16(f.data.Size()),
NextHeader: f.nextHdr,
HopLimit: 255,
- SrcAddr: addr1,
- DstAddr: addr2,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
})
vv := hdr.View().ToVectorisedView()
vv.Append(f.data)
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: vv,
- })
+ }))
}
if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
@@ -1263,3 +1831,308 @@ func TestReceiveIPv6Fragments(t *testing.T) {
})
}
}
+
+func TestInvalidIPv6Fragments(t *testing.T) {
+ const (
+ nicID = 1
+ fragmentExtHdrLen = 8
+ )
+
+ payloadGen := func(payloadLen int) []byte {
+ payload := make([]byte, payloadLen)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = 0x30
+ }
+ return payload
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ wantMalformedIPPackets uint64
+ wantMalformedFragments uint64
+ }{
+ {
+ name: "fragments reassembled into a payload exceeding the max IPv6 payload size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16,
+ []buffer.View{
+ // Fragment extension header.
+ // Fragment offset = 8190, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8,
+ ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8,
+ 0, 0, 0, 1}),
+ // Payload length = 16
+ payloadGen(16),
+ },
+ ),
+ },
+ },
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ NewProtocol,
+ },
+ })
+ e := channel.New(0, 1500, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+
+ // Serialize IPv6 fixed header.
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(f.data.Size()),
+ NextHeader: f.nextHdr,
+ HopLimit: 255,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
+ })
+
+ vv := hdr.View().ToVectorisedView()
+ vv.Append(f.data)
+
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
+ t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want)
+ }
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
+ t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
+func TestWriteStats(t *testing.T) {
+ const nPackets = 3
+ tests := []struct {
+ name string
+ setup func(*testing.T, *stack.Stack)
+ allowPackets int
+ expectSent int
+ expectDropped int
+ expectWritten int
+ }{
+ {
+ name: "Accept all",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets,
+ expectDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Accept all with error",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ allowPackets: nPackets - 1,
+ expectSent: nPackets - 1,
+ expectDropped: 0,
+ expectWritten: nPackets - 1,
+ }, {
+ name: "Drop all",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: 0,
+ expectDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule that matches only 1
+ // of the 3 packets.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ allowPackets: math.MaxInt32,
+ expectSent: nPackets - 1,
+ expectDropped: 1,
+ expectWritten: nPackets,
+ },
+ }
+
+ writers := []struct {
+ name string
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ }{
+ {
+ name: "WritePacket",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ nWritten := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
+ return nWritten, err
+ }
+ nWritten++
+ }
+ return nWritten, nil
+ },
+ }, {
+ name: "WritePackets",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
+ },
+ },
+ }
+
+ for _, writer := range writers {
+ t.Run(writer.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ rt := buildRoute(t, ep)
+
+ var pkts stack.PacketBufferList
+ for i := 0; i < nPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(header.UDPMinimumSize)
+ pkts.PushBack(pkt)
+ }
+
+ test.setup(t, rt.Stack())
+
+ nWritten, _ := writer.writePackets(&rt, pkts)
+
+ if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
+ t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
+ t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ }
+ if nWritten != test.expectWritten {
+ t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ }
+ })
+ }
+ })
+ }
+}
+
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC(1, _) failed: %s", err)
+ }
+ const (
+ src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ )
+ if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
+ if err != nil {
+ t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err)
+ }
+ return rt
+}
+
+// limitedMatcher is an iptables matcher that matches after a certain number of
+// packets are checked against it.
+type limitedMatcher struct {
+ limit int
+}
+
+// Name implements Matcher.Name.
+func (*limitedMatcher) Name() string {
+ return "limitedMatcher"
+}
+
+// Match implements Matcher.Match.
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+ if lm.limit == 0 {
+ return true, false
+ }
+ lm.limit--
+ return false, false
+}
+
+func TestClearEndpointFromProtocolOnClose(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint)
+ {
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[ep]
+ proto.mu.Unlock()
+ if !hasEP {
+ t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep)
+ }
+ }
+
+ ep.Close()
+
+ {
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[ep]
+ proto.mu.Unlock()
+ if hasEP {
+ t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index e28c23d66..48a4c65e3 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package stack
+package ipv6
import (
"fmt"
@@ -23,9 +23,27 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending reachability probes.
+ //
+ // Default taken from RETRANS_TIMER of RFC 4861 section 10.
+ defaultRetransmitTimer = time.Second
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending reachability probes.
+ //
+ // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here
+ // to make sure the messages are not sent all at once. We also come to this
+ // value because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1. Note, the
+ // unit of the RetransmitTimer field in the Router Advertisement is
+ // milliseconds.
+ minimumRetransmitTimer = time.Millisecond
+
// defaultDupAddrDetectTransmits is the default number of NDP Neighbor
// Solicitation messages to send when doing Duplicate Address Detection
// for a tentative address.
@@ -33,14 +51,8 @@ const (
// Default = 1 (from RFC 4862 section 5.1)
defaultDupAddrDetectTransmits = 1
- // defaultRetransmitTimer is the default amount of time to wait between
- // sending NDP Neighbor solicitation messages.
- //
- // Default = 1s (from RFC 4861 section 10).
- defaultRetransmitTimer = time.Second
-
// defaultMaxRtrSolicitations is the default number of Router
- // Solicitation messages to send when a NIC becomes enabled.
+ // Solicitation messages to send when an IPv6 endpoint becomes enabled.
//
// Default = 3 (from RFC 4861 section 10).
defaultMaxRtrSolicitations = 3
@@ -79,16 +91,6 @@ const (
// Default = true.
defaultAutoGenGlobalAddresses = true
- // minimumRetransmitTimer is the minimum amount of time to wait between
- // sending NDP Neighbor solicitation messages. Note, RFC 4861 does
- // not impose a minimum Retransmit Timer, but we do here to make sure
- // the messages are not sent all at once. We also come to this value
- // because in the RetransmitTimer field of a Router Advertisement, a
- // value of 0 means unspecified, so the smallest valid value is 1.
- // Note, the unit of the RetransmitTimer field in the Router
- // Advertisement is milliseconds.
- minimumRetransmitTimer = time.Millisecond
-
// minimumRtrSolicitationInterval is the minimum amount of time to wait
// between sending Router Solicitation messages. This limit is imposed
// to make sure that Router Solicitation messages are not sent all at
@@ -147,7 +149,7 @@ const (
minRegenAdvanceDuration = time.Duration(0)
// maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt
- // SLAAC address regenerations in response to a NIC-local conflict.
+ // SLAAC address regenerations in response to an IPv6 endpoint-local conflict.
maxSLAACAddrLocalRegenAttempts = 10
)
@@ -179,7 +181,7 @@ var (
// This is exported as a variable (instead of a constant) so tests
// can update it to a smaller value.
//
- // This value guarantees that a temporary address will be preferred for at
+ // This value guarantees that a temporary address is preferred for at
// least 1hr if the SLAAC prefix is valid for at least that time.
MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour
@@ -189,11 +191,17 @@ var (
// This is exported as a variable (instead of a constant) so tests
// can update it to a smaller value.
//
- // This value guarantees that a temporary address will be valid for at least
+ // This value guarantees that a temporary address is valid for at least
// 2hrs if the SLAAC prefix is valid for at least that time.
MinMaxTempAddrValidLifetime = 2 * time.Hour
)
+// NDPEndpoint is an endpoint that supports NDP.
+type NDPEndpoint interface {
+ // SetNDPConfigurations sets the NDP configurations.
+ SetNDPConfigurations(NDPConfigurations)
+}
+
// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an
// NDP Router Advertisement informed the Stack about.
type DHCPv6ConfigurationFromNDPRA int
@@ -208,7 +216,7 @@ const (
// DHCPv6ManagedAddress indicates that addresses are available via DHCPv6.
//
// DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6
- // will return all available configuration information.
+ // returns all available configuration information when serving addresses.
DHCPv6ManagedAddress
// DHCPv6OtherConfigurations indicates that other configuration information is
@@ -223,19 +231,18 @@ const (
// NDPDispatcher is the interface integrators of netstack must implement to
// receive and handle NDP related events.
type NDPDispatcher interface {
- // OnDuplicateAddressDetectionStatus will be called when the DAD process
- // for an address (addr) on a NIC (with ID nicID) completes. resolved
- // will be set to true if DAD completed successfully (no duplicate addr
- // detected); false otherwise (addr was detected to be a duplicate on
- // the link the NIC is a part of, or it was stopped for some other
- // reason, such as the address being removed). If an error occured
- // during DAD, err will be set and resolved must be ignored.
+ // OnDuplicateAddressDetectionStatus is called when the DAD process for an
+ // address (addr) on a NIC (with ID nicID) completes. resolved is set to true
+ // if DAD completed successfully (no duplicate addr detected); false otherwise
+ // (addr was detected to be a duplicate on the link the NIC is a part of, or
+ // it was stopped for some other reason, such as the address being removed).
+ // If an error occured during DAD, err is set and resolved must be ignored.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
- // OnDefaultRouterDiscovered will be called when a new default router is
+ // OnDefaultRouterDiscovered is called when a new default router is
// discovered. Implementations must return true if the newly discovered
// router should be remembered.
//
@@ -243,56 +250,55 @@ type NDPDispatcher interface {
// is also not permitted to call into the stack.
OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool
- // OnDefaultRouterInvalidated will be called when a discovered default
- // router that was remembered is invalidated.
+ // OnDefaultRouterInvalidated is called when a discovered default router that
+ // was remembered is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address)
- // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is
- // discovered. Implementations must return true if the newly discovered
- // on-link prefix should be remembered.
+ // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered.
+ // Implementations must return true if the newly discovered on-link prefix
+ // should be remembered.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool
- // OnOnLinkPrefixInvalidated will be called when a discovered on-link
- // prefix that was remembered is invalidated.
+ // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that
+ // was remembered is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet)
- // OnAutoGenAddress will be called when a new prefix with its
- // autonomous address-configuration flag set has been received and SLAAC
- // has been performed. Implementations may prevent the stack from
- // assigning the address to the NIC by returning false.
+ // OnAutoGenAddress is called when a new prefix with its autonomous address-
+ // configuration flag set is received and SLAAC was performed. Implementations
+ // may prevent the stack from assigning the address to the NIC by returning
+ // false.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
- // OnAutoGenAddressDeprecated will be called when an auto-generated
- // address (as part of SLAAC) has been deprecated, but is still
- // considered valid. Note, if an address is invalidated at the same
- // time it is deprecated, the deprecation event MAY be omitted.
+ // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC)
+ // is deprecated, but is still considered valid. Note, if an address is
+ // invalidated at the same ime it is deprecated, the deprecation event may not
+ // be received.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix)
- // OnAutoGenAddressInvalidated will be called when an auto-generated
- // address (as part of SLAAC) has been invalidated.
+ // OnAutoGenAddressInvalidated is called when an auto-generated address
+ // (SLAAC) is invalidated.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix)
- // OnRecursiveDNSServerOption will be called when an NDP option with
- // recursive DNS servers has been received. Note, addrs may contain
- // link-local addresses.
+ // OnRecursiveDNSServerOption is called when the stack learns of DNS servers
+ // through NDP. Note, the addresses may contain link-local addresses.
//
// It is up to the caller to use the DNS Servers only for their valid
// lifetime. OnRecursiveDNSServerOption may be called for new or
@@ -304,8 +310,8 @@ type NDPDispatcher interface {
// call functions on the stack itself.
OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
- // OnDNSSearchListOption will be called when an NDP option with a DNS
- // search list has been received.
+ // OnDNSSearchListOption is called when the stack learns of DNS search lists
+ // through NDP.
//
// It is up to the caller to use the domain names in the search list
// for only their valid lifetime. OnDNSSearchListOption may be called
@@ -314,8 +320,8 @@ type NDPDispatcher interface {
// be increased, decreased or completely invalidated when lifetime = 0.
OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration)
- // OnDHCPv6Configuration will be called with an updated configuration that is
- // available via DHCPv6 for a specified NIC.
+ // OnDHCPv6Configuration is called with an updated configuration that is
+ // available via DHCPv6 for the passed NIC.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
@@ -336,7 +342,7 @@ type NDPConfigurations struct {
// Must be greater than or equal to 1ms.
RetransmitTimer time.Duration
- // The number of Router Solicitation messages to send when the NIC
+ // The number of Router Solicitation messages to send when the IPv6 endpoint
// becomes enabled.
MaxRtrSolicitations uint8
@@ -351,24 +357,22 @@ type NDPConfigurations struct {
// Must be greater than or equal to 0s.
MaxRtrSolicitationDelay time.Duration
- // HandleRAs determines whether or not Router Advertisements will be
- // processed.
+ // HandleRAs determines whether or not Router Advertisements are processed.
HandleRAs bool
- // DiscoverDefaultRouters determines whether or not default routers will
- // be discovered from Router Advertisements. This configuration is
- // ignored if HandleRAs is false.
+ // DiscoverDefaultRouters determines whether or not default routers are
+ // discovered from Router Advertisements, as per RFC 4861 section 6. This
+ // configuration is ignored if HandleRAs is false.
DiscoverDefaultRouters bool
- // DiscoverOnLinkPrefixes determines whether or not on-link prefixes
- // will be discovered from Router Advertisements' Prefix Information
- // option. This configuration is ignored if HandleRAs is false.
+ // DiscoverOnLinkPrefixes determines whether or not on-link prefixes are
+ // discovered from Router Advertisements' Prefix Information option, as per
+ // RFC 4861 section 6. This configuration is ignored if HandleRAs is false.
DiscoverOnLinkPrefixes bool
- // AutoGenGlobalAddresses determines whether or not global IPv6
- // addresses will be generated for a NIC in response to receiving a new
- // Prefix Information option with its Autonomous Address
- // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC).
+ // AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs
+ // SLAAC to auto-generate global SLAAC addresses in response to Prefix
+ // Information options, as per RFC 4862.
//
// Note, if an address was already generated for some unique prefix, as
// part of SLAAC, this option does not affect whether or not the
@@ -382,12 +386,12 @@ type NDPConfigurations struct {
//
// If the method used to generate the address does not support creating
// alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's
- // MAC address), then no attempt will be made to resolve the conflict.
+ // MAC address), then no attempt is made to resolve the conflict.
AutoGenAddressConflictRetries uint8
// AutoGenTempGlobalAddresses determines whether or not temporary SLAAC
- // addresses will be generated for a NIC as part of SLAAC privacy extensions,
- // RFC 4941.
+ // addresses are generated for an IPv6 endpoint as part of SLAAC privacy
+ // extensions, as per RFC 4941.
//
// Ignored if AutoGenGlobalAddresses is false.
AutoGenTempGlobalAddresses bool
@@ -426,7 +430,7 @@ func DefaultNDPConfigurations() NDPConfigurations {
}
// validate modifies an NDPConfigurations with valid values. If invalid values
-// are present in c, the corresponding default values will be used instead.
+// are present in c, the corresponding default values are used instead.
func (c *NDPConfigurations) validate() {
if c.RetransmitTimer < minimumRetransmitTimer {
c.RetransmitTimer = defaultRetransmitTimer
@@ -455,8 +459,8 @@ func (c *NDPConfigurations) validate() {
// ndpState is the per-interface NDP state.
type ndpState struct {
- // The NIC this ndpState is for.
- nic *NIC
+ // The IPv6 endpoint this ndpState is for.
+ ep *endpoint
// configs is the per-interface NDP configurations.
configs NDPConfigurations
@@ -469,13 +473,13 @@ type ndpState struct {
rtrSolicit struct {
// The timer used to send the next router solicitation message.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the Router Solicitation timer know that it has been stopped.
//
// Must only be read from or written to while protected by the lock of
- // the NIC this ndpState is associated with. MUST be set when the timer is
- // set.
+ // the IPv6 endpoint this ndpState is associated with. MUST be set when the
+ // timer is set.
done *bool
}
@@ -503,57 +507,57 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the DAD timer know that it has been stopped.
//
// Must only be read from or written to while protected by the lock of
- // the NIC this dadState is associated with.
+ // the IPv6 endpoint this dadState is associated with.
done *bool
}
// defaultRouterState holds data associated with a default router discovered by
// a Router Advertisement (RA).
type defaultRouterState struct {
- // Timer to invalidate the default router.
+ // Job to invalidate the default router.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// onLinkPrefixState holds data associated with an on-link prefix discovered by
// a Router Advertisement's Prefix Information option (PI) when the NDP
// configurations was configured to do so.
type onLinkPrefixState struct {
- // Timer to invalidate the on-link prefix.
+ // Job to invalidate the on-link prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// tempSLAACAddrState holds state associated with a temporary SLAAC address.
type tempSLAACAddrState struct {
- // Timer to deprecate the temporary SLAAC address.
+ // Job to deprecate the temporary SLAAC address.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the temporary SLAAC address.
+ // Job to invalidate the temporary SLAAC address.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
- // Timer to regenerate the temporary SLAAC address.
+ // Job to regenerate the temporary SLAAC address.
//
// Must not be nil.
- regenTimer *tcpip.CancellableTimer
+ regenJob *tcpip.Job
createdAt time.Time
// The address's endpoint.
//
// Must not be nil.
- ref *referencedNetworkEndpoint
+ addressEndpoint stack.AddressEndpoint
// Has a new temporary SLAAC address already been regenerated?
regenerated bool
@@ -561,15 +565,15 @@ type tempSLAACAddrState struct {
// slaacPrefixState holds state associated with a SLAAC prefix.
type slaacPrefixState struct {
- // Timer to deprecate the prefix.
+ // Job to deprecate the prefix.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the prefix.
+ // Job to invalidate the prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
// Nonzero only when the address is not valid forever.
validUntil time.Time
@@ -583,10 +587,10 @@ type slaacPrefixState struct {
//
// May only be nil when the address is being (re-)generated. Otherwise,
// must not be nil as all SLAAC prefixes must have a stable address.
- ref *referencedNetworkEndpoint
+ addressEndpoint stack.AddressEndpoint
- // The number of times an address has been generated locally where the NIC
- // already had the generated address.
+ // The number of times an address has been generated locally where the IPv6
+ // endpoint already had the generated address.
localGenerationFailures uint8
}
@@ -594,11 +598,12 @@ type slaacPrefixState struct {
tempAddrs map[tcpip.Address]tempSLAACAddrState
// The next two fields are used by both stable and temporary addresses
- // generated for a SLAAC prefix. This is safe as only 1 address will be
- // in the generation and DAD process at any time. That is, no two addresses
- // will be generated at the same time for a given SLAAC prefix.
+ // generated for a SLAAC prefix. This is safe as only 1 address is in the
+ // generation and DAD process at any time. That is, no two addresses are
+ // generated at the same time for a given SLAAC prefix.
- // The number of times an address has been generated and added to the NIC.
+ // The number of times an address has been generated and added to the IPv6
+ // endpoint.
//
// Addresses may be regenerated in reseponse to a DAD conflicts.
generationAttempts uint8
@@ -613,16 +618,16 @@ type slaacPrefixState struct {
// This function must only be called by IPv6 addresses that are currently
// tentative.
//
-// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
// addr must be a valid unicast IPv6 address.
if !header.IsV6UnicastAddress(addr) {
return tcpip.ErrAddressFamilyNotSupported
}
- if ref.getKind() != permanentTentative {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
// The endpoint should be marked as tentative since we are starting DAD.
- panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
// Should not attempt to perform DAD on an address that is currently in the
@@ -633,45 +638,45 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// existed, we would get an error since we attempted to add a duplicate
// address, or its reference count would have been increased without doing
// the work that would have been done for an address that was brand new.
- // See NIC.addAddressLocked.
- panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
+ // See endpoint.addAddressLocked.
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID()))
}
remaining := ndp.configs.DupAddrDetectTransmits
if remaining == 0 {
- ref.setKind(permanent)
+ addressEndpoint.SetKind(stack.Permanent)
// Consider DAD to have resolved even if no DAD messages were actually
// transmitted.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
return nil
}
var done bool
- var timer *time.Timer
+ var timer tcpip.Timer
// We initially start a timer to fire immediately because some of the DAD work
- // cannot be done while holding the NIC's lock. This is effectively the same
- // as starting a goroutine but we use a timer that fires immediately so we can
- // reset it for the next DAD iteration.
- timer = time.AfterFunc(0, func() {
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
+ // cannot be done while holding the IPv6 endpoint's lock. This is effectively
+ // the same as starting a goroutine but we use a timer that fires immediately
+ // so we can reset it for the next DAD iteration.
+ timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
+ ndp.ep.mu.Lock()
+ defer ndp.ep.mu.Unlock()
if done {
// If we reach this point, it means that the DAD timer fired after
- // another goroutine already obtained the NIC lock and stopped DAD
- // before this function obtained the NIC lock. Simply return here and do
- // nothing further.
+ // another goroutine already obtained the IPv6 endpoint lock and stopped
+ // DAD before this function obtained the NIC lock. Simply return here and
+ // do nothing further.
return
}
- if ref.getKind() != permanentTentative {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
// The endpoint should still be marked as tentative since we are still
// performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
dadDone := remaining == 0
@@ -679,33 +684,34 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
var err *tcpip.Error
if !dadDone {
// Use the unspecified address as the source address when performing DAD.
- ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+ addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
// Do not hold the lock when sending packets which may be a long running
// task or may block link address resolution. We know this is safe
// because immediately after obtaining the lock again, we check if DAD
- // has been stopped before doing any work with the NIC. Note, DAD would be
- // stopped if the NIC was disabled or removed, or if the address was
- // removed.
- ndp.nic.mu.Unlock()
- err = ndp.sendDADPacket(addr, ref)
- ndp.nic.mu.Lock()
+ // has been stopped before doing any work with the IPv6 endpoint. Note,
+ // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if
+ // the address was removed.
+ ndp.ep.mu.Unlock()
+ err = ndp.sendDADPacket(addr, addressEndpoint)
+ ndp.ep.mu.Lock()
+ addressEndpoint.DecRef()
}
if done {
// If we reach this point, it means that DAD was stopped after we released
- // the NIC's read lock and before we obtained the write lock.
+ // the IPv6 endpoint's read lock and before we obtained the write lock.
return
}
if dadDone {
// DAD has resolved.
- ref.setKind(permanent)
+ addressEndpoint.SetKind(stack.Permanent)
} else if err == nil {
// DAD is not done and we had no errors when sending the last NDP NS,
// schedule the next DAD timer.
remaining--
- timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
+ timer.Reset(ndp.configs.RetransmitTimer)
return
}
@@ -714,16 +720,16 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// integrator know DAD has completed.
delete(ndp.dad, addr)
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
}
// If DAD resolved for a stable SLAAC address, attempt generation of a
// temporary SLAAC address.
- if dadDone && ref.configType == slaac {
+ if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
// Reset the generation attempts counter as we are starting the generation
// of a new address for the SLAAC prefix.
- ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */)
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
}
})
@@ -738,44 +744,50 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns
// addr.
//
-// addr must be a tentative IPv6 address on ndp's NIC.
+// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
//
-// The NIC ndp belongs to MUST NOT be locked.
-func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+// The IPv6 endpoint that ndp belongs to MUST NOT be locked.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
defer r.Release()
// Route should resolve immediately since snmc is a multicast address so a
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
// Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the NIC is not locked,
- // the NIC could have been disabled or removed by another goroutine.
+ // Since this method is required to be called when the IPv6 endpoint is not
+ // locked, the NIC could have been disabled or removed by another goroutine.
if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
return err
}
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err))
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
} else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()))
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
}
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
+ icmpData.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmpData.NDPPayload())
ns.SetTargetAddress(addr)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(icmpData).ToVectorisedView(),
+ })
sent := r.Stats().ICMP.V6PacketsSent
if err := r.WritePacket(nil,
- NetworkHeaderParams{
+ stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- TOS: DefaultTOS,
- }, &PacketBuffer{Header: hdr},
+ }, pkt,
); err != nil {
sent.Dropped.Increment()
return err
@@ -790,11 +802,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd
// such a state forever, unless some other external event resolves the DAD
// process (receiving an NA from the true owner of addr, or an NS for addr
// (implying another node is attempting to use addr)). It is up to the caller
-// of this function to handle such a scenario. Normally, addr will be removed
-// from n right after this function returns or the address successfully
-// resolved.
+// of this function to handle such a scenario.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
dad, ok := ndp.dad[addr]
if !ok {
@@ -813,30 +823,30 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
delete(ndp.dad, addr)
// Let the integrator know DAD did not resolve.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
}
}
// handleRA handles a Router Advertisement message that arrived on the NIC
// this ndp is for. Does nothing if the NIC is configured to not handle RAs.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
- // Is the NIC configured to handle RAs at all?
+ // Is the IPv6 endpoint configured to handle RAs at all?
//
// Currently, the stack does not determine router interface status on a
- // per-interface basis; it is a stack-wide configuration, so we check
- // stack's forwarding flag to determine if the NIC is a routing
- // interface.
- if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding {
+ // per-interface basis; it is a protocol-wide configuration, so we check the
+ // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding
+ // packets.
+ if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() {
return
}
// Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
// only inform the dispatcher on configuration changes. We do nothing else
// with the information.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
var configuration DHCPv6ConfigurationFromNDPRA
switch {
case ra.ManagedAddrConfFlag():
@@ -851,11 +861,11 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
if ndp.dhcpv6Configuration != configuration {
ndp.dhcpv6Configuration = configuration
- ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration)
+ ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration)
}
}
- // Is the NIC configured to discover default routers?
+ // Is the IPv6 endpoint configured to discover default routers?
if ndp.configs.DiscoverDefaultRouters {
rtr, ok := ndp.defaultRouters[ip]
rl := ra.RouterLifetime()
@@ -871,9 +881,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
case ok && rl != 0:
// This is an already discovered default router. Update
- // the invalidation timer.
- rtr.invalidationTimer.StopLocked()
- rtr.invalidationTimer.Reset(rl)
+ // the invalidation job.
+ rtr.invalidationJob.Cancel()
+ rtr.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = rtr
case ok && rl == 0:
@@ -893,20 +903,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
switch opt := opt.(type) {
case header.NDPRecursiveDNSServer:
- if ndp.nic.stack.ndpDisp == nil {
+ if ndp.ep.protocol.ndpDisp == nil {
continue
}
addrs, _ := opt.Addresses()
- ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime())
+ ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
case header.NDPDNSSearchList:
- if ndp.nic.stack.ndpDisp == nil {
+ if ndp.ep.protocol.ndpDisp == nil {
continue
}
domainNames, _ := opt.DomainNames()
- ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime())
+ ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
case header.NDPPrefixInformation:
prefix := opt.Subnet()
@@ -940,7 +950,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// invalidateDefaultRouter invalidates a discovered default router.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
rtr, ok := ndp.defaultRouters[ip]
@@ -950,41 +960,41 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
return
}
- rtr.invalidationTimer.StopLocked()
+ rtr.invalidationJob.Cancel()
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
}
}
// rememberDefaultRouter remembers a newly discovered default router with IPv6
// link-local address ip with lifetime rl.
//
-// The router identified by ip MUST NOT already be known by the NIC.
+// The router identified by ip MUST NOT already be known by the IPv6 endpoint.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
- ndpDisp := ndp.nic.stack.ndpDisp
+ ndpDisp := ndp.ep.protocol.ndpDisp
if ndpDisp == nil {
return
}
// Inform the integrator when we discovered a default router.
- if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) {
+ if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) {
// Informed by the integrator to not remember the router, do
// nothing further.
return
}
state := defaultRouterState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
- state.invalidationTimer.Reset(rl)
+ state.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = state
}
@@ -994,28 +1004,28 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
//
// The prefix identified by prefix MUST NOT already be known.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
- ndpDisp := ndp.nic.stack.ndpDisp
+ ndpDisp := ndp.ep.protocol.ndpDisp
if ndpDisp == nil {
return
}
// Inform the integrator when we discovered an on-link prefix.
- if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) {
+ if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) {
// Informed by the integrator to not remember the prefix, do
// nothing further.
return
}
state := onLinkPrefixState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
if l < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(l)
+ state.invalidationJob.Schedule(l)
}
ndp.onLinkPrefixes[prefix] = state
@@ -1023,7 +1033,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
// invalidateOnLinkPrefix invalidates a discovered on-link prefix.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
s, ok := ndp.onLinkPrefixes[prefix]
@@ -1033,12 +1043,12 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
return
}
- s.invalidationTimer.StopLocked()
+ s.invalidationJob.Cancel()
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix)
}
}
@@ -1048,7 +1058,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
// handleOnLinkPrefixInformation assumes that the prefix this pi is for is
// not the link-local prefix and the on-link flag is set.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) {
prefix := pi.Subnet()
prefixState, ok := ndp.onLinkPrefixes[prefix]
@@ -1082,14 +1092,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// This is an already discovered on-link prefix with a
// new non-zero valid lifetime.
//
- // Update the invalidation timer.
+ // Update the invalidation job.
- prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationJob.Cancel()
if vl < header.NDPInfiniteLifetime {
- // Prefix is valid for a finite lifetime, reset the timer to expire after
+ // Prefix is valid for a finite lifetime, schedule the job to execute after
// the new valid lifetime.
- prefixState.invalidationTimer.Reset(vl)
+ prefixState.invalidationJob.Schedule(vl)
}
ndp.onLinkPrefixes[prefix] = prefixState
@@ -1101,7 +1111,7 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
// not the link-local prefix and the autonomous flag is set.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
vl := pi.ValidLifetime()
pl := pi.PreferredLifetime()
@@ -1137,7 +1147,7 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// If we do not already have an address for this prefix and the valid
// lifetime is 0, no need to do anything further, as per RFC 4862
@@ -1154,15 +1164,15 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
}
- ndp.deprecateSLAACAddress(state.stableAddr.ref)
+ ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1184,24 +1194,24 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
if !ndp.generateSLAACAddr(prefix, &state) {
// We were unable to generate an address for the prefix, we do not nothing
- // further as there is no reason to maintain state or timers for a prefix we
+ // further as there is no reason to maintain state or jobs for a prefix we
// do not have an address for.
return
}
- // Setup the initial timers to deprecate and invalidate prefix.
+ // Setup the initial jobs to deprecate and invalidate prefix.
if pl < header.NDPInfiniteLifetime && pl != 0 {
- state.deprecationTimer.Reset(pl)
+ state.deprecationJob.Schedule(pl)
}
if vl < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(vl)
+ state.invalidationJob.Schedule(vl)
state.validUntil = now.Add(vl)
}
// If the address is assigned (DAD resolved), generate a temporary address.
- if state.stableAddr.ref.getKind() == permanent {
+ if state.stableAddr.addressEndpoint.GetKind() == stack.Permanent {
// Reset the generation attempts counter as we are starting the generation
// of a new address for the SLAAC prefix.
ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */)
@@ -1210,32 +1220,27 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.slaacPrefixes[prefix] = state
}
-// addSLAACAddr adds a SLAAC address to the NIC.
+// addAndAcquireSLAACAddr adds a SLAAC address to the IPv6 endpoint.
//
-// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint {
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint {
// Inform the integrator that we have a new SLAAC address.
- ndpDisp := ndp.nic.stack.ndpDisp
+ ndpDisp := ndp.ep.protocol.ndpDisp
if ndpDisp == nil {
return nil
}
- if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) {
+ if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) {
// Informed by the integrator not to add the address.
return nil
}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr,
- }
-
- ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated)
+ addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
if err != nil {
- panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err))
+ panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
}
- return ref
+ return addressEndpoint
}
// generateSLAACAddr generates a SLAAC address for prefix.
@@ -1244,10 +1249,10 @@ func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType netwo
//
// Panics if the prefix is not a SLAAC prefix or it already has an address.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool {
- if r := state.stableAddr.ref; r != nil {
- panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix()))
+ if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil {
+ panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, addressEndpoint.AddressWithPrefix()))
}
// If we have already reached the maximum address generation attempts for the
@@ -1267,11 +1272,11 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
}
dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures
- if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil {
addrBytes = header.AppendOpaqueInterfaceIdentifier(
addrBytes[:header.IIDOffsetInIPv6Address],
prefix,
- oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name),
+ oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()),
dadCounter,
oIID.SecretKey,
)
@@ -1284,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
//
// TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
// LinkEndpoint.LinkAddress) before reaching this point.
- linkAddr := ndp.nic.linkEP.LinkAddress()
+ linkAddr := ndp.ep.linkEP.LinkAddress()
if !header.IsValidUnicastEthernetAddress(linkAddr) {
return false
}
@@ -1303,15 +1308,15 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
PrefixLen: validPrefixLenForAutoGen,
}
- if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) {
+ if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) {
break
}
state.stableAddr.localGenerationFailures++
}
- if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil {
- state.stableAddr.ref = ref
+ if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil {
+ state.stableAddr.addressEndpoint = addressEndpoint
state.generationAttempts++
return true
}
@@ -1321,10 +1326,9 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
// regenerateSLAACAddr regenerates an address for a SLAAC prefix.
//
-// If generating a new address for the prefix fails, the prefix will be
-// invalidated.
+// If generating a new address for the prefix fails, the prefix is invalidated.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
@@ -1344,7 +1348,7 @@ func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
// generateTempSLAACAddr generates a new temporary SLAAC address.
//
-// If resetGenAttempts is true, the prefix's generation counter will be reset.
+// If resetGenAttempts is true, the prefix's generation counter is reset.
//
// Returns true if a new address was generated.
func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool {
@@ -1365,7 +1369,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
return false
}
- stableAddr := prefixState.stableAddr.ref.ep.ID().LocalAddress
+ stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address
now := time.Now()
// As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
@@ -1404,7 +1408,8 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
return false
}
- // Attempt to generate a new address that is not already assigned to the NIC.
+ // Attempt to generate a new address that is not already assigned to the IPv6
+ // endpoint.
var generatedAddr tcpip.AddressWithPrefix
for i := 0; ; i++ {
// If we were unable to generate an address after the maximum SLAAC address
@@ -1414,7 +1419,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr)
- if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) {
+ if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) {
break
}
}
@@ -1422,13 +1427,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
// As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary
// address with a zero preferred lifetime. The checks above ensure this
// so we know the address is not deprecated.
- ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */)
- if ref == nil {
+ addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaacTemp, false /* deprecated */)
+ if addressEndpoint == nil {
return false
}
state := tempSLAACAddrState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1439,9 +1444,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr))
}
- ndp.deprecateSLAACAddress(tempAddrState.ref)
+ ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1454,7 +1459,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1477,13 +1482,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
prefixState.tempAddrs[generatedAddr.Address] = tempAddrState
ndp.slaacPrefixes[prefix] = prefixState
}),
- createdAt: now,
- ref: ref,
+ createdAt: now,
+ addressEndpoint: addressEndpoint,
}
- state.deprecationTimer.Reset(pl)
- state.invalidationTimer.Reset(vl)
- state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration)
+ state.deprecationJob.Schedule(pl)
+ state.invalidationJob.Schedule(vl)
+ state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration)
prefixState.generationAttempts++
prefixState.tempAddrs[generatedAddr.Address] = state
@@ -1493,7 +1498,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
@@ -1508,26 +1513,26 @@ func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttemp
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) {
// If the preferred lifetime is zero, then the prefix should be deprecated.
deprecated := pl == 0
if deprecated {
- ndp.deprecateSLAACAddress(prefixState.stableAddr.ref)
+ ndp.deprecateSLAACAddress(prefixState.stableAddr.addressEndpoint)
} else {
- prefixState.stableAddr.ref.deprecated = false
+ prefixState.stableAddr.addressEndpoint.SetDeprecated(false)
}
- // If prefix was preferred for some finite lifetime before, stop the
- // deprecation timer so it can be reset.
- prefixState.deprecationTimer.StopLocked()
+ // If prefix was preferred for some finite lifetime before, cancel the
+ // deprecation job so it can be reset.
+ prefixState.deprecationJob.Cancel()
now := time.Now()
- // Reset the deprecation timer if prefix has a finite preferred lifetime.
+ // Schedule the deprecation job if prefix has a finite preferred lifetime.
if pl < header.NDPInfiniteLifetime {
if !deprecated {
- prefixState.deprecationTimer.Reset(pl)
+ prefixState.deprecationJob.Schedule(pl)
}
prefixState.preferredUntil = now.Add(pl)
} else {
@@ -1546,9 +1551,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
if vl >= header.NDPInfiniteLifetime {
- // Handle the infinite valid lifetime separately as we do not keep a timer
- // in this case.
- prefixState.invalidationTimer.StopLocked()
+ // Handle the infinite valid lifetime separately as we do not schedule a
+ // job in this case.
+ prefixState.invalidationJob.Cancel()
prefixState.validUntil = time.Time{}
} else {
var effectiveVl time.Duration
@@ -1569,20 +1574,20 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
if effectiveVl != 0 {
- prefixState.invalidationTimer.StopLocked()
- prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.invalidationJob.Cancel()
+ prefixState.invalidationJob.Schedule(effectiveVl)
prefixState.validUntil = now.Add(effectiveVl)
}
}
// If DAD is not yet complete on the stable address, there is no need to do
// work with temporary addresses.
- if prefixState.stableAddr.ref.getKind() != permanent {
+ if prefixState.stableAddr.addressEndpoint.GetKind() != stack.Permanent {
return
}
// Note, we do not need to update the entries in the temporary address map
- // after updating the timers because the timers are held as pointers.
+ // after updating the jobs because the jobs are held as pointers.
var regenForAddr tcpip.Address
allAddressesRegenerated := true
for tempAddr, tempAddrState := range prefixState.tempAddrs {
@@ -1596,14 +1601,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer valid, invalidate it immediately. Otherwise,
- // reset the invalidation timer.
+ // reset the invalidation job.
newValidLifetime := validUntil.Sub(now)
if newValidLifetime <= 0 {
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
continue
}
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.invalidationTimer.Reset(newValidLifetime)
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.invalidationJob.Schedule(newValidLifetime)
// As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
// address is the lower of the preferred lifetime of the stable address or
@@ -1616,17 +1621,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer preferred, deprecate it immediately.
- // Otherwise, reset the deprecation timer.
+ // Otherwise, schedule the deprecation job again.
newPreferredLifetime := preferredUntil.Sub(now)
- tempAddrState.deprecationTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
if newPreferredLifetime <= 0 {
- ndp.deprecateSLAACAddress(tempAddrState.ref)
+ ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
} else {
- tempAddrState.ref.deprecated = false
- tempAddrState.deprecationTimer.Reset(newPreferredLifetime)
+ tempAddrState.addressEndpoint.SetDeprecated(false)
+ tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
}
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.regenJob.Cancel()
if tempAddrState.regenerated {
} else {
allAddressesRegenerated = false
@@ -1637,7 +1642,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// immediately after we finish iterating over the temporary addresses.
regenForAddr = tempAddr
} else {
- tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
}
}
}
@@ -1647,8 +1652,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// due to an update in preferred lifetime.
//
// If each temporay address has already been regenerated, no new temporary
- // address will be generated. To ensure continuation of temporary SLAAC
- // addresses, we manually try to regenerate an address here.
+ // address is generated. To ensure continuation of temporary SLAAC addresses,
+ // we manually try to regenerate an address here.
if len(regenForAddr) != 0 || allAddressesRegenerated {
// Reset the generation attempts counter as we are starting the generation
// of a new address for the SLAAC prefix.
@@ -1659,57 +1664,58 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
}
-// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP
-// dispatcher that ref has been deprecated.
+// deprecateSLAACAddress marks the address as deprecated and notifies the NDP
+// dispatcher that address has been deprecated.
//
-// deprecateSLAACAddress does nothing if ref is already deprecated.
+// deprecateSLAACAddress does nothing if the address is already deprecated.
//
-// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) {
- if ref.deprecated {
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint) {
+ if addressEndpoint.Deprecated() {
return
}
- ref.deprecated = true
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix())
+ addressEndpoint.SetDeprecated(true)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
}
}
// invalidateSLAACPrefix invalidates a SLAAC prefix.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) {
- if r := state.stableAddr.ref; r != nil {
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+
+ if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil {
// Since we are already invalidating the prefix, do not invalidate the
// prefix when removing the address.
- if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil {
- panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err))
+ if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err))
}
}
-
- ndp.cleanupSLAACPrefixResources(prefix, state)
}
// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's
// resources.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
prefix := addr.Subnet()
state, ok := ndp.slaacPrefixes[prefix]
- if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.ep.ID().LocalAddress {
+ if !ok || state.stableAddr.addressEndpoint == nil || addr.Address != state.stableAddr.addressEndpoint.AddressWithPrefix().Address {
return
}
if !invalidatePrefix {
// If the prefix is not being invalidated, disassociate the address from the
// prefix and do nothing further.
- state.stableAddr.ref = nil
+ state.stableAddr.addressEndpoint.DecRef()
+ state.stableAddr.addressEndpoint = nil
ndp.slaacPrefixes[prefix] = state
return
}
@@ -1717,31 +1723,34 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
ndp.cleanupSLAACPrefixResources(prefix, state)
}
-// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry.
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry.
//
// Panics if the SLAAC prefix is not known.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) {
// Invalidate all temporary addresses.
for tempAddr, tempAddrState := range state.tempAddrs {
ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState)
}
- state.stableAddr.ref = nil
- state.deprecationTimer.StopLocked()
- state.invalidationTimer.StopLocked()
+ if state.stableAddr.addressEndpoint != nil {
+ state.stableAddr.addressEndpoint.DecRef()
+ state.stableAddr.addressEndpoint = nil
+ }
+ state.deprecationJob.Cancel()
+ state.invalidationJob.Cancel()
delete(ndp.slaacPrefixes, prefix)
}
// invalidateTempSLAACAddr invalidates a temporary SLAAC address.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
// Since we are already invalidating the address, do not invalidate the
// address when removing the address.
- if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil {
- panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err))
+ if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err))
}
ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState)
@@ -1750,10 +1759,10 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA
// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary
// SLAAC address's resources from ndp.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
if !invalidateAddr {
@@ -1775,37 +1784,31 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi
}
// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's
-// timers and entry.
+// jobs and entry.
//
-// The NIC that ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
- tempAddrState.deprecationTimer.StopLocked()
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.addressEndpoint.DecRef()
+ tempAddrState.addressEndpoint = nil
+ tempAddrState.deprecationJob.Cancel()
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.regenJob.Cancel()
delete(tempAddrs, tempAddr)
}
-// cleanupState cleans up ndp's state.
-//
-// If hostOnly is true, then only host-specific state will be cleaned up.
-//
-// cleanupState MUST be called with hostOnly set to true when ndp's NIC is
-// transitioning from a host to a router. This function will invalidate all
-// discovered on-link prefixes, discovered routers, and auto-generated
-// addresses.
+// removeSLAACAddresses removes all SLAAC addresses.
//
-// If hostOnly is true, then the link-local auto-generated address will not be
-// invalidated as routers are also expected to generate a link-local address.
+// If keepLinkLocal is false, the SLAAC generated link-local address is removed.
//
-// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupState(hostOnly bool) {
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) {
linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
- linkLocalPrefixes := 0
+ var linkLocalPrefixes int
for prefix, state := range ndp.slaacPrefixes {
// RFC 4862 section 5 states that routers are also expected to generate a
// link-local address so we do not invalidate them if we are cleaning up
// host-only state.
- if hostOnly && prefix == linkLocalSubnet {
+ if keepLinkLocal && prefix == linkLocalSubnet {
linkLocalPrefixes++
continue
}
@@ -1816,6 +1819,21 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes))
}
+}
+
+// cleanupState cleans up ndp's state.
+//
+// If hostOnly is true, then only host-specific state is cleaned up.
+//
+// This function invalidates all discovered on-link prefixes, discovered
+// routers, and auto-generated addresses.
+//
+// If hostOnly is true, then the link-local auto-generated address aren't
+// invalidated as routers are also expected to generate a link-local address.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupState(hostOnly bool) {
+ ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */)
for prefix := range ndp.onLinkPrefixes {
ndp.invalidateOnLinkPrefix(prefix)
@@ -1839,7 +1857,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
// 6.3.7. If routers are already being solicited, this function does nothing.
//
-// The NIC ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
if ndp.rtrSolicit.timer != nil {
// We are already soliciting routers.
@@ -1860,27 +1878,37 @@ func (ndp *ndpState) startSolicitingRouters() {
var done bool
ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = time.AfterFunc(delay, func() {
- ndp.nic.mu.Lock()
+ ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
+ ndp.ep.mu.Lock()
if done {
// If we reach this point, it means that the RS timer fired after another
- // goroutine already obtained the NIC lock and stopped solicitations.
- // Simply return here and do nothing further.
- ndp.nic.mu.Unlock()
+ // goroutine already obtained the IPv6 endpoint lock and stopped
+ // solicitations. Simply return here and do nothing further.
+ ndp.ep.mu.Unlock()
return
}
// As per RFC 4861 section 4.1, the source of the RS is an address assigned
// to the sending interface, or the unspecified address if no address is
// assigned to the sending interface.
- ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress)
- if ref == nil {
- ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+ addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
+ if addressEndpoint == nil {
+ // Incase this ends up creating a new temporary address, we need to hold
+ // onto the endpoint until a route is obtained. If we decrement the
+ // reference count before obtaing a route, the address's resources would
+ // be released and attempting to obtain a route after would fail. Once a
+ // route is obtainted, it is safe to decrement the reference count since
+ // obtaining a route increments the address's reference count.
+ addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
}
- ndp.nic.mu.Unlock()
+ ndp.ep.mu.Unlock()
- localAddr := ref.ep.ID().LocalAddress
- r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ localAddr := addressEndpoint.AddressWithPrefix().Address
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
+ addressEndpoint.DecRef()
+ if err != nil {
+ return
+ }
defer r.Release()
// Route should resolve immediately since
@@ -1888,15 +1916,16 @@ func (ndp *ndpState) startSolicitingRouters() {
// remote link address can be calculated without a resolution process.
if c, err := r.Resolve(nil); err != nil {
// Do not consider the NIC being unknown or disabled as a fatal error.
- // Since this method is required to be called when the NIC is not locked,
- // the NIC could have been disabled or removed by another goroutine.
+ // Since this method is required to be called when the IPv6 endpoint is
+ // not locked, the IPv6 endpoint could have been disabled or removed by
+ // another goroutine.
if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
return
}
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err))
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
} else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()))
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
}
// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
@@ -1913,23 +1942,26 @@ func (ndp *ndpState) startSolicitingRouters() {
}
}
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize)
- pkt := header.ICMPv6(hdr.Prepend(payloadSize))
- pkt.SetType(header.ICMPv6RouterSolicit)
- rs := header.NDPRouterSolicit(pkt.NDPPayload())
+ icmpData := header.ICMPv6(buffer.NewView(payloadSize))
+ icmpData.SetType(header.ICMPv6RouterSolicit)
+ rs := header.NDPRouterSolicit(icmpData.NDPPayload())
rs.Options().Serialize(optsSerializer)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(icmpData).ToVectorisedView(),
+ })
sent := r.Stats().ICMP.V6PacketsSent
if err := r.WritePacket(nil,
- NetworkHeaderParams{
+ stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- TOS: DefaultTOS,
- }, &PacketBuffer{Header: hdr},
+ }, pkt,
); err != nil {
sent.Dropped.Increment()
- log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
remaining = 0
} else {
@@ -1937,19 +1969,19 @@ func (ndp *ndpState) startSolicitingRouters() {
remaining--
}
- ndp.nic.mu.Lock()
+ ndp.ep.mu.Lock()
if done || remaining == 0 {
ndp.rtrSolicit.timer = nil
ndp.rtrSolicit.done = nil
} else if ndp.rtrSolicit.timer != nil {
// Note, we need to explicitly check to make sure that
// the timer field is not nil because if it was nil but
- // we still reached this point, then we know the NIC
+ // we still reached this point, then we know the IPv6 endpoint
// was requested to stop soliciting routers so we don't
// need to send the next Router Solicitation message.
ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
}
- ndp.nic.mu.Unlock()
+ ndp.ep.mu.Unlock()
})
}
@@ -1957,7 +1989,7 @@ func (ndp *ndpState) startSolicitingRouters() {
// stopSolicitingRouters stops soliciting routers. If routers are not currently
// being solicited, this function does nothing.
//
-// The NIC ndp belongs to MUST be locked.
+// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
if ndp.rtrSolicit.timer == nil {
// Nothing to do.
@@ -1973,7 +2005,7 @@ func (ndp *ndpState) stopSolicitingRouters() {
// initializeTempAddrState initializes state related to temporary SLAAC
// addresses.
func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID())
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 64239ce9a..25464a03a 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -17,7 +17,9 @@ package ipv6
import (
"strings"
"testing"
+ "time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -30,12 +32,13 @@ import (
// setupStackAndEndpoint creates a stack with a single NIC with a link-local
// address llladdr and an IPv6 endpoint to a remote with link-local address
// rlladdr
-func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
+func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeighborCache bool) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ UseNeighborCache: useNeighborCache,
})
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
@@ -63,14 +66,94 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
}
+ t.Cleanup(ep.Close)
return s, ep
}
+var _ NDPDispatcher = (*testNDPDispatcher)(nil)
+
+// testNDPDispatcher is an NDPDispatcher only allows default router discovery.
+type testNDPDispatcher struct {
+ addr tcpip.Address
+}
+
+func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+}
+
+func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
+ t.addr = addr
+ return true
+}
+
+func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) {
+ t.addr = addr
+}
+
+func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
+ return false
+}
+
+func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {
+}
+
+func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
+ return false
+}
+
+func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {
+}
+
+func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {
+}
+
+func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {
+}
+
+func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {
+}
+
+func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) {
+}
+
+func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
+ var ndpDisp testNDPDispatcher
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
+ NDPDisp: &ndpDisp,
+ })},
+ })
+
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err)
+ }
+
+ ipv6EP := ep.(*endpoint)
+ ipv6EP.mu.Lock()
+ ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour)
+ ipv6EP.mu.Unlock()
+
+ if ndpDisp.addr != lladdr1 {
+ t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
+ }
+
+ ndpDisp.addr = ""
+ ndpEP := ep.(stack.NDPEndpoint)
+ ndpEP.InvalidateDefaultRouter(lladdr1)
+ if ndpDisp.addr != lladdr1 {
+ t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
+ }
+}
+
// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
@@ -100,7 +183,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
e := channel.New(0, 1280, linkAddr0)
if err := s.CreateNIC(nicID, e); err != nil {
@@ -136,9 +219,9 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
if linkAddr != test.expectedLinkAddr {
@@ -174,6 +257,123 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
}
}
+// TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests
+// that receiving a valid NDP NS message with the Source Link Layer Address
+// option results in a new entry in the link address cache for the sender of
+// the message.
+func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ neighbors, err := s.Neighbors(nicID)
+ if err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ }
+
+ neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
+ for _, n := range neighbors {
+ if existing, ok := neighborByAddr[n.Addr]; ok {
+ if diff := cmp.Diff(existing, n); diff != "" {
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
+ }
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing)
+ }
+ neighborByAddr[n.Addr] = n
+ }
+
+ if neigh, ok := neighborByAddr[lladdr1]; len(test.expectedLinkAddr) != 0 {
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+
+ if !ok {
+ t.Fatalf("expected a neighbor entry for %q", lladdr1)
+ }
+ if neigh.LinkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", neigh.LinkAddr, test.expectedLinkAddr)
+ }
+ if neigh.State != stack.Stale {
+ t.Errorf("got NUD state = %s, want = %s", neigh.State, stack.Stale)
+ }
+ } else {
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+
+ if ok {
+ t.Fatalf("unexpectedly got neighbor entry: %s", neigh)
+ }
+ }
+ })
+ }
+}
+
func TestNeighorSolicitationResponse(t *testing.T) {
const nicID = 1
nicAddr := lladdr0
@@ -183,6 +383,20 @@ func TestNeighorSolicitationResponse(t *testing.T) {
remoteLinkAddr0 := linkAddr1
remoteLinkAddr1 := linkAddr2
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
+
tests := []struct {
name string
nsOpts header.NDPOptionsSerializer
@@ -341,86 +555,92 @@ func TestNeighorSolicitationResponse(t *testing.T) {
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- e := channel.New(1, 1280, nicLinkAddr)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
- }
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: stackTyp.useNeighborCache,
+ })
+ e := channel.New(1, 1280, nicLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ }
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
- ns.SetTargetAddress(nicAddr)
- opts := ns.Options()
- opts.Serialize(test.nsOpts)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: test.nsSrc,
- DstAddr: test.nsDst,
- })
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(nicAddr)
+ opts := ns.Options()
+ opts.Serialize(test.nsOpts)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
+ e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
- e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
+ if test.nsInvalid {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
- if test.nsInvalid {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
+ if p, got := e.Read(); got {
+ t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
+ }
- if p, got := e.Read(); got {
- t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
- }
+ // If we expected the NS to be invalid, we have nothing else to check.
+ return
+ }
- // If we expected the NS to be invalid, we have nothing else to check.
- return
- }
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
+ p, got := e.Read()
+ if !got {
+ t.Fatal("expected an NDP NA response")
+ }
- p, got := e.Read()
- if !got {
- t.Fatal("expected an NDP NA response")
- }
+ if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ }
- if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.naSrc),
+ checker.DstAddr(test.naDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNA(
+ checker.NDPNASolicitedFlag(test.naSolicited),
+ checker.NDPNATargetAddress(nicAddr),
+ checker.NDPNAOptions([]header.NDPOption{
+ header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
+ }),
+ ))
+ })
}
-
- checker.IPv6(t, p.Pkt.Header.View(),
- checker.SrcAddr(test.naSrc),
- checker.DstAddr(test.naDst),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNA(
- checker.NDPNASolicitedFlag(test.naSolicited),
- checker.NDPNATargetAddress(nicAddr),
- checker.NDPNAOptions([]header.NDPOption{
- header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
- }),
- ))
})
}
}
@@ -461,7 +681,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
e := channel.New(0, 1280, linkAddr0)
if err := s.CreateNIC(nicID, e); err != nil {
@@ -497,9 +717,9 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
- })
+ }))
linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
if linkAddr != test.expectedLinkAddr {
@@ -535,201 +755,385 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
}
}
-func TestNDPValidation(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
- t.Helper()
-
- // Create a stack with the assigned link-local address lladdr0
- // and an endpoint to lladdr1.
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
-
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
- }
-
- return s, ep, r
- }
-
- handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
- nextHdr := uint8(header.ICMPv6ProtocolNumber)
- var extensions buffer.View
- if atomicFragment {
- extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
- extensions[0] = nextHdr
- nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
- }
-
- ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions)))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload) + len(extensions)),
- NextHeader: nextHdr,
- HopLimit: hopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- })
- if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
- t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
- }
- ep.HandlePacket(r, &stack.PacketBuffer{
- NetworkHeader: buffer.View(ip),
- Data: payload.ToVectorisedView(),
- })
- }
-
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
+// TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests
+// that receiving a valid NDP NA message with the Target Link Layer Address
+// option does not result in a new entry in the neighbor cache for the target
+// of the message.
+func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) {
+ const nicID = 1
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- extraData []byte
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ tests := []struct {
+ name string
+ optsBuf []byte
+ isValid bool
}{
{
- name: "RouterSolicit",
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- },
+ name: "Valid",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
+ isValid: true,
},
{
- name: "RouterAdvert",
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- },
- },
- {
- name: "NeighborSolicit",
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- },
+ name: "Too Small",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
},
{
- name: "NeighborAdvert",
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- },
+ name: "Invalid Length",
+ optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
},
{
- name: "RedirectMsg",
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
+ name: "Multiple",
+ optsBuf: []byte{
+ 2, 1, 2, 3, 4, 5, 6, 7,
+ 2, 1, 2, 3, 4, 5, 6, 8,
},
},
}
- subTests := []struct {
- name string
- atomicFragment bool
- hopLimit uint8
- code uint8
- valid bool
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr1)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ neighbors, err := s.Neighbors(nicID)
+ if err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ }
+
+ neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
+ for _, n := range neighbors {
+ if existing, ok := neighborByAddr[n.Addr]; ok {
+ if diff := cmp.Diff(existing, n); diff != "" {
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
+ }
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing)
+ }
+ neighborByAddr[n.Addr] = n
+ }
+
+ if neigh, ok := neighborByAddr[lladdr1]; ok {
+ t.Fatalf("unexpectedly got neighbor entry: %s", neigh)
+ }
+
+ if test.isValid {
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+func TestNDPValidation(t *testing.T) {
+ stacks := []struct {
+ name string
+ useNeighborCache bool
}{
{
- name: "Valid",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: true,
+ name: "linkAddrCache",
+ useNeighborCache: false,
},
{
- name: "Fragmented",
- atomicFragment: true,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: false,
- },
- {
- name: "Invalid hop limit",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit - 1,
- code: 0,
- valid: false,
- },
- {
- name: "Invalid ICMPv6 code",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 1,
- valid: false,
+ name: "neighborCache",
+ useNeighborCache: true,
},
}
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- for _, test := range subTests {
- t.Run(test.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ t.Helper()
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
+ // Create a stack with the assigned link-local address lladdr0
+ // and an endpoint to lladdr1.
+ s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache)
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
- // Rx count of the NDP message should initially be 0.
- if got := typStat.Value(); got != 0 {
- t.Errorf("got %s = %d, want = 0", typ.name, got)
- }
+ return s, ep, r
+ }
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ nextHdr := uint8(header.ICMPv6ProtocolNumber)
+ var extensions buffer.View
+ if atomicFragment {
+ extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
+ extensions[0] = nextHdr
+ nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ }
- if t.Failed() {
- t.FailNow()
- }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions),
+ Data: payload.ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions)))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(payload) + len(extensions)),
+ NextHeader: nextHdr,
+ HopLimit: hopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
+ t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
+ }
+ ep.HandlePacket(r, pkt)
+ }
- handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
- // Rx count of the NDP packet should have increased.
- if got := typStat.Value(); got != 1 {
- t.Errorf("got %s = %d, want = 1", typ.name, got)
- }
+ var sllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
- want := uint64(0)
- if !test.valid {
- // Invalid count should have increased.
- want = 1
- }
- if got := invalid.Value(); got != want {
- t.Errorf("got invalid = %d, want = %d", got, want)
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ routerOnly bool
+ }{
+ {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ routerOnly: true,
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ extraData: sllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ atomicFragment bool
+ hopLimit uint8
+ code header.ICMPv6Code
+ valid bool
+ }{
+ {
+ name: "Valid",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: true,
+ },
+ {
+ name: "Fragmented",
+ atomicFragment: true,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid hop limit",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit - 1,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid ICMPv6 code",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 1,
+ valid: false,
+ },
+ }
+
+ for _, typ := range types {
+ for _, isRouter := range []bool{false, true} {
+ name := typ.name
+ if isRouter {
+ name += " (Router)"
}
- })
+
+ t.Run(name, func(t *testing.T) {
+ for _, test := range subTests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep, r := setup(t)
+ defer r.Release()
+
+ if isRouter {
+ // Enabling forwarding makes the stack act as a router.
+ s.SetForwarding(ProtocolNumber, true)
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ routerOnly := stats.RouterOnlyPacketsDroppedByHost
+ typStat := typ.statCounter(stats)
+
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetCode(test.code)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+
+ // Rx count of the NDP message should initially be 0.
+ if got := typStat.Value(); got != 0 {
+ t.Errorf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+
+ // RouterOnlyPacketsReceivedByHost count should initially be 0.
+ if got := routerOnly.Value(); got != 0 {
+ t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+
+ // Rx count of the NDP packet should have increased.
+ if got := typStat.Value(); got != 1 {
+ t.Errorf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ want := uint64(0)
+ if !test.valid {
+ // Invalid count should have increased.
+ want = 1
+ }
+ if got := invalid.Value(); got != want {
+ t.Errorf("got invalid = %d, want = %d", got, want)
+ }
+
+ want = 0
+ if test.valid && !isRouter && typ.routerOnly {
+ // RouterOnlyPacketsReceivedByHost count should have increased.
+ want = 1
+ }
+ if got := routerOnly.Value(); got != want {
+ t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want)
+ }
+
+ })
+ }
+ })
+ }
}
})
}
+
}
// TestRouterAdvertValidation tests that when the NIC is configured to handle
// NDP Router Advertisement packets, it validates the Router Advertisement
// properly before handling them.
func TestRouterAdvertValidation(t *testing.T) {
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
+
tests := []struct {
name string
src tcpip.Address
hopLimit uint8
- code uint8
+ code header.ICMPv6Code
ndpPayload []byte
expectedSuccess bool
}{
@@ -846,61 +1250,67 @@ func TestRouterAdvertValidation(t *testing.T) {
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: stackTyp.useNeighborCache,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
+ icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(test.code)
+ copy(pkt.NDPPayload(), test.ndpPayload)
+ payloadLength := hdr.UsedLength()
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
- icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(header.ICMPv6RouterAdvert)
- pkt.SetCode(test.code)
- copy(pkt.NDPPayload(), test.ndpPayload)
- payloadLength := hdr.UsedLength()
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- })
-
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- rxRA := stats.RouterAdvert
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ rxRA := stats.RouterAdvert
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := rxRA.Value(); got != 0 {
- t.Fatalf("got rxRA = %d, want = 0", got)
- }
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := rxRA.Value(); got != 0 {
+ t.Fatalf("got rxRA = %d, want = 0", got)
+ }
- e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
- if got := rxRA.Value(); got != 1 {
- t.Fatalf("got rxRA = %d, want = 1", got)
- }
+ if got := rxRA.Value(); got != 1 {
+ t.Fatalf("got rxRA = %d, want = 1", got)
+ }
- if test.expectedSuccess {
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- } else {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
+ if test.expectedSuccess {
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
}
})
}
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
new file mode 100644
index 000000000..c9e57dc0d
--- /dev/null
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "testutil",
+ srcs = [
+ "testutil.go",
+ ],
+ visibility = [
+ "//pkg/tcpip/network/ipv4:__pkg__",
+ "//pkg/tcpip/network/ipv6:__pkg__",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
new file mode 100644
index 000000000..7cc52985e
--- /dev/null
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -0,0 +1,144 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil defines types and functions used to test Network Layer
+// functionality such as IP fragmentation.
+package testutil
+
+import (
+ "fmt"
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// MockLinkEndpoint is an endpoint used for testing, it stores packets written
+// to it and can mock errors.
+type MockLinkEndpoint struct {
+ // WrittenPackets is where packets written to the endpoint are stored.
+ WrittenPackets []*stack.PacketBuffer
+
+ mtu uint32
+ err *tcpip.Error
+ allowPackets int
+}
+
+// NewMockLinkEndpoint creates a new MockLinkEndpoint.
+//
+// err is the error that will be returned once allowPackets packets are written
+// to the endpoint.
+func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint {
+ return &MockLinkEndpoint{
+ mtu: mtu,
+ err: err,
+ allowPackets: allowPackets,
+ }
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu }
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 }
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
+
+// LinkAddress implements LinkEndpoint.LinkAddress.
+func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (ep *MockLinkEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if ep.allowPackets == 0 {
+ return ep.err
+ }
+ ep.allowPackets--
+ ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ var n int
+
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := ep.WritePacket(r, gso, protocol, pkt); err != nil {
+ return n, err
+ }
+ n++
+ }
+
+ return n, nil
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ if ep.allowPackets == 0 {
+ return ep.err
+ }
+ ep.allowPackets--
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+ ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+
+ return nil
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (*MockLinkEndpoint) IsAttached() bool { return false }
+
+// Wait implements LinkEndpoint.Wait.
+func (*MockLinkEndpoint) Wait() {}
+
+// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
+func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
+
+// AddHeader implements LinkEndpoint.AddHeader.
+func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
+// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
+// how many random bytes will be copied in the Transport Header.
+// extraHeaderReserveLength indicates how much extra space will be reserved for
+// the other headers. The payload is made from Views of the sizes listed in
+// viewSizes.
+func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer {
+ var views buffer.VectorisedView
+
+ for _, s := range viewSizes {
+ newView := buffer.NewView(s)
+ if _, err := rand.Read(newView); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ views.AppendView(newView)
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength,
+ Data: views,
+ })
+ pkt.NetworkProtocolNumber = proto
+ if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ return pkt
+}
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index f6d592eb5..d87193650 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -400,7 +400,11 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) (reservedPort uint16, err *tcpip.Error) {
+//
+// An optional testPort closure can be passed in which if provided will be used
+// to test if the picked port can be used. The function should return true if
+// the port is safe to use, false otherwise.
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -412,12 +416,23 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) {
return 0, tcpip.ErrPortInUse
}
+ if testPort != nil && !testPort(port) {
+ s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst)
+ return 0, tcpip.ErrPortInUse
+ }
return port, nil
}
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst), nil
+ if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) {
+ return false, nil
+ }
+ if testPort != nil && !testPort(p) {
+ s.releasePortLocked(networks, transport, addr, p, flags.Bits(), bindToDevice, dst)
+ return false, nil
+ }
+ return true, nil
})
}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 58db5868c..4bc949fd8 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -332,7 +332,7 @@ func TestPortReservation(t *testing.T) {
pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest)
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */)
if err != test.want {
t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want)
}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 0ab089208..51d428049 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -127,8 +127,8 @@ func main() {
// Create the stack with ipv4 and tcp protocols, then add a tun-based
// NIC and ipv4 address.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
mtu, err := rawfile.GetMTU(tunName)
@@ -182,7 +182,7 @@ func main() {
if terr == tcpip.ErrConnectStarted {
fmt.Println("Connect is pending...")
<-notifyCh
- terr = ep.GetSockOpt(tcpip.ErrorOption{})
+ terr = ep.LastError()
}
wq.EventUnregister(&waitEntry)
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 9e37cab18..8e0ee1cd7 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -112,8 +112,8 @@ func main() {
// Create the stack with ip and tcp protocols, then add a tun-based
// NIC and address.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
mtu, err := rawfile.GetMTU(tunName)
@@ -188,7 +188,7 @@ func main() {
defer wq.EventUnregister(&waitEntry)
for {
- n, wq, err := ep.Accept()
+ n, wq, err := ep.Accept(nil)
if err != nil {
if err == tcpip.ErrWouldBlock {
<-notifyCh
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index e65c731c2..2eaeab779 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -16,6 +16,18 @@ go_template_instance(
)
go_template_instance(
+ name = "neighbor_entry_list",
+ out = "neighbor_entry_list.go",
+ package = "stack",
+ prefix = "neighborEntry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*neighborEntry",
+ "Linker": "*neighborEntry",
+ },
+)
+
+go_template_instance(
name = "packet_buffer_list",
out = "packet_buffer_list.go",
package = "stack",
@@ -27,20 +39,38 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "tuple_list",
+ out = "tuple_list.go",
+ package = "stack",
+ prefix = "tuple",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*tuple",
+ "Linker": "*tuple",
+ },
+)
+
go_library(
name = "stack",
srcs = [
+ "addressable_endpoint_state.go",
"conntrack.go",
- "dhcpv6configurationfromndpra_string.go",
"forwarder.go",
+ "headertype_string.go",
"icmp_rate_limit.go",
"iptables.go",
+ "iptables_state.go",
"iptables_targets.go",
"iptables_types.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
- "ndp.go",
+ "neighbor_cache.go",
+ "neighbor_entry.go",
+ "neighbor_entry_list.go",
+ "neighborstate_string.go",
"nic.go",
+ "nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
"rand.go",
@@ -50,6 +80,7 @@ go_library(
"stack_global_state.go",
"stack_options.go",
"transport_demuxer.go",
+ "tuple_list.go",
],
visibility = ["//visibility:public"],
deps = [
@@ -74,7 +105,9 @@ go_test(
name = "stack_x_test",
size = "medium",
srcs = [
+ "addressable_endpoint_state_test.go",
"ndp_test.go",
+ "nud_test.go",
"stack_test.go",
"transport_demuxer_test.go",
"transport_test.go",
@@ -83,19 +116,22 @@ go_test(
deps = [
":stack",
"//pkg/rand",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/ports",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
@@ -105,7 +141,10 @@ go_test(
srcs = [
"forwarder_test.go",
"linkaddrcache_test.go",
+ "neighbor_cache_test.go",
+ "neighbor_entry_test.go",
"nic_test.go",
+ "packet_buffer_test.go",
],
library = ":stack",
deps = [
@@ -113,6 +152,9 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
new file mode 100644
index 000000000..4d3acab96
--- /dev/null
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -0,0 +1,753 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil)
+var _ AddressableEndpoint = (*AddressableEndpointState)(nil)
+
+// AddressableEndpointState is an implementation of an AddressableEndpoint.
+type AddressableEndpointState struct {
+ networkEndpoint NetworkEndpoint
+
+ // Lock ordering (from outer to inner lock ordering):
+ //
+ // AddressableEndpointState.mu
+ // addressState.mu
+ mu struct {
+ sync.RWMutex
+
+ endpoints map[tcpip.Address]*addressState
+ primary []*addressState
+
+ // groups holds the mapping between group addresses and the number of times
+ // they have been joined.
+ groups map[tcpip.Address]uint32
+ }
+}
+
+// Init initializes the AddressableEndpointState with networkEndpoint.
+//
+// Must be called before calling any other function on m.
+func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) {
+ a.networkEndpoint = networkEndpoint
+
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.mu.endpoints = make(map[tcpip.Address]*addressState)
+ a.mu.groups = make(map[tcpip.Address]uint32)
+}
+
+// ReadOnlyAddressableEndpointState provides read-only access to an
+// AddressableEndpointState.
+type ReadOnlyAddressableEndpointState struct {
+ inner *AddressableEndpointState
+}
+
+// AddrOrMatching returns an endpoint for the passed address that is consisdered
+// bound to the wrapped AddressableEndpointState.
+//
+// If addr is an exact match with an existing address, that address is returned.
+// Otherwise, f is called with each address and the address that f returns true
+// for is returned.
+//
+// Returns nil of no address matches.
+func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint {
+ m.inner.mu.RLock()
+ defer m.inner.mu.RUnlock()
+
+ if ep, ok := m.inner.mu.endpoints[addr]; ok {
+ if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() {
+ return ep
+ }
+ }
+
+ for _, ep := range m.inner.mu.endpoints {
+ if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() {
+ return ep
+ }
+ }
+
+ return nil
+}
+
+// Lookup returns the AddressEndpoint for the passed address.
+//
+// Returns nil if the passed address is not associated with the
+// AddressableEndpointState.
+func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint {
+ m.inner.mu.RLock()
+ defer m.inner.mu.RUnlock()
+
+ ep, ok := m.inner.mu.endpoints[addr]
+ if !ok {
+ return nil
+ }
+ return ep
+}
+
+// ForEach calls f for each address pair.
+//
+// If f returns false, f is no longer be called.
+func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) {
+ m.inner.mu.RLock()
+ defer m.inner.mu.RUnlock()
+
+ for _, ep := range m.inner.mu.endpoints {
+ if !f(ep) {
+ return
+ }
+ }
+}
+
+// ForEachPrimaryEndpoint calls f for each primary address.
+//
+// If f returns false, f is no longer be called.
+func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
+ m.inner.mu.RLock()
+ defer m.inner.mu.RUnlock()
+ for _, ep := range m.inner.mu.primary {
+ f(ep)
+ }
+}
+
+// ReadOnly returns a readonly reference to a.
+func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState {
+ return ReadOnlyAddressableEndpointState{inner: a}
+}
+
+func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.releaseAddressStateLocked(addrState)
+}
+
+// releaseAddressState removes addrState from s's address state (primary and endpoints list).
+//
+// Preconditions: a.mu must be write locked.
+func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressState) {
+ oldPrimary := a.mu.primary
+ for i, s := range a.mu.primary {
+ if s == addrState {
+ a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
+ oldPrimary[len(oldPrimary)-1] = nil
+ break
+ }
+ }
+ delete(a.mu.endpoints, addrState.addr.Address)
+}
+
+// AddAndAcquirePermanentAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */)
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since addAndAcquireAddressLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
+ return nil, err
+ }
+ return ep, err
+}
+
+// AddAndAcquireTemporaryAddress adds a temporary address.
+//
+// Returns tcpip.ErrDuplicateAddress if the address exists.
+//
+// The temporary address's endpoint is acquired and returned.
+func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since addAndAcquireAddressLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
+ return nil, err
+ }
+ return ep, err
+}
+
+// addAndAcquireAddressLocked adds, acquires and returns a permanent or
+// temporary address.
+//
+// If the addressable endpoint already has the address in a non-permanent state,
+// and addAndAcquireAddressLocked is adding a permanent address, that address is
+// promoted in place and its properties set to the properties provided. If the
+// address already exists in any other state, then tcpip.ErrDuplicateAddress is
+// returned, regardless the kind of address that is being added.
+//
+// Precondition: a.mu must be write locked.
+func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) {
+ // attemptAddToPrimary is false when the address is already in the primary
+ // address list.
+ attemptAddToPrimary := true
+ addrState, ok := a.mu.endpoints[addr.Address]
+ if ok {
+ if !permanent {
+ // We are adding a non-permanent address but the address exists. No need
+ // to go any further since we can only promote existing temporary/expired
+ // addresses to permanent.
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
+ addrState.mu.Lock()
+ if addrState.mu.kind.IsPermanent() {
+ addrState.mu.Unlock()
+ // We are adding a permanent address but a permanent address already
+ // exists.
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
+ if addrState.mu.refs == 0 {
+ panic(fmt.Sprintf("found an address that should have been released (ref count == 0); address = %s", addrState.addr))
+ }
+
+ // We now promote the address.
+ for i, s := range a.mu.primary {
+ if s == addrState {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ // The address is already in the primary address list.
+ attemptAddToPrimary = false
+ case FirstPrimaryEndpoint:
+ if i == 0 {
+ // The address is already first in the primary address list.
+ attemptAddToPrimary = false
+ } else {
+ a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
+ }
+ case NeverPrimaryEndpoint:
+ a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
+ default:
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ }
+ break
+ }
+ }
+ }
+
+ if addrState == nil {
+ addrState = &addressState{
+ addressableEndpointState: a,
+ addr: addr,
+ }
+ a.mu.endpoints[addr.Address] = addrState
+ addrState.mu.Lock()
+ // We never promote an address to temporary - it can only be added as such.
+ // If we are actaully adding a permanent address, it is promoted below.
+ addrState.mu.kind = Temporary
+ }
+
+ // At this point we have an address we are either promoting from an expired or
+ // temporary address to permanent, promoting an expired address to temporary,
+ // or we are adding a new temporary or permanent address.
+ //
+ // The address MUST be write locked at this point.
+ defer addrState.mu.Unlock()
+
+ if permanent {
+ if addrState.mu.kind.IsPermanent() {
+ panic(fmt.Sprintf("only non-permanent addresses should be promoted to permanent; address = %s", addrState.addr))
+ }
+
+ // Primary addresses are biased by 1.
+ addrState.mu.refs++
+ addrState.mu.kind = Permanent
+ }
+ // Acquire the address before returning it.
+ addrState.mu.refs++
+ addrState.mu.deprecated = deprecated
+ addrState.mu.configType = configType
+
+ if attemptAddToPrimary {
+ switch peb {
+ case NeverPrimaryEndpoint:
+ case CanBePrimaryEndpoint:
+ a.mu.primary = append(a.mu.primary, addrState)
+ case FirstPrimaryEndpoint:
+ if cap(a.mu.primary) == len(a.mu.primary) {
+ a.mu.primary = append([]*addressState{addrState}, a.mu.primary...)
+ } else {
+ // Shift all the endpoints by 1 to make room for the new address at the
+ // front. We could have just created a new slice but this saves
+ // allocations when the slice has capacity for the new address.
+ primaryCount := len(a.mu.primary)
+ a.mu.primary = append(a.mu.primary, nil)
+ if n := copy(a.mu.primary[1:], a.mu.primary); n != primaryCount {
+ panic(fmt.Sprintf("copied %d elements; expected = %d elements", n, primaryCount))
+ }
+ a.mu.primary[0] = addrState
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ }
+ }
+
+ return addrState, nil
+}
+
+// RemovePermanentAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if _, ok := a.mu.groups[addr]; ok {
+ panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr))
+ }
+
+ return a.removePermanentAddressLocked(addr)
+}
+
+// removePermanentAddressLocked is like RemovePermanentAddress but with locking
+// requirements.
+//
+// Precondition: a.mu must be write locked.
+func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
+ addrState, ok := a.mu.endpoints[addr]
+ if !ok {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ return a.removePermanentEndpointLocked(addrState)
+}
+
+// RemovePermanentEndpoint removes the passed endpoint if it is associated with
+// a and permanent.
+func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error {
+ addrState, ok := ep.(*addressState)
+ if !ok || addrState.addressableEndpointState != a {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ return a.removePermanentEndpointLocked(addrState)
+}
+
+// removePermanentAddressLocked is like RemovePermanentAddress but with locking
+// requirements.
+//
+// Precondition: a.mu must be write locked.
+func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error {
+ if !addrState.GetKind().IsPermanent() {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ addrState.SetKind(PermanentExpired)
+ a.decAddressRefLocked(addrState)
+ return nil
+}
+
+// decAddressRef decrements the address's reference count and releases it once
+// the reference count hits 0.
+func (a *AddressableEndpointState) decAddressRef(addrState *addressState) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.decAddressRefLocked(addrState)
+}
+
+// decAddressRefLocked is like decAddressRef but with locking requirements.
+//
+// Precondition: a.mu must be write locked.
+func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState) {
+ addrState.mu.Lock()
+ defer addrState.mu.Unlock()
+
+ if addrState.mu.refs == 0 {
+ panic(fmt.Sprintf("attempted to decrease ref count for AddressEndpoint w/ addr = %s when it is already released", addrState.addr))
+ }
+
+ addrState.mu.refs--
+
+ if addrState.mu.refs != 0 {
+ return
+ }
+
+ // A non-expired permanent address must not have its reference count dropped
+ // to 0.
+ if addrState.mu.kind.IsPermanent() {
+ panic(fmt.Sprintf("permanent addresses should be removed through the AddressableEndpoint: addr = %s, kind = %d", addrState.addr, addrState.mu.kind))
+ }
+
+ a.releaseAddressStateLocked(addrState)
+}
+
+// MainAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
+ return ep.GetKind() == Permanent
+ })
+ if ep == nil {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ addr := ep.AddressWithPrefix()
+ a.decAddressRefLocked(ep)
+ return addr
+}
+
+// acquirePrimaryAddressRLocked returns an acquired primary address that is
+// valid according to isValid.
+//
+// Precondition: e.mu must be read locked
+func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*addressState) bool) *addressState {
+ var deprecatedEndpoint *addressState
+ for _, ep := range a.mu.primary {
+ if !isValid(ep) {
+ continue
+ }
+
+ if !ep.Deprecated() {
+ if ep.IncRef() {
+ // ep is not deprecated, so return it immediately.
+ //
+ // If we kept track of a deprecated endpoint, decrement its reference
+ // count since it was incremented when we decided to keep track of it.
+ if deprecatedEndpoint != nil {
+ a.decAddressRefLocked(deprecatedEndpoint)
+ deprecatedEndpoint = nil
+ }
+
+ return ep
+ }
+ } else if deprecatedEndpoint == nil && ep.IncRef() {
+ // We prefer an endpoint that is not deprecated, but we keep track of
+ // ep in case a doesn't have any non-deprecated endpoints.
+ //
+ // If we end up finding a more preferred endpoint, ep's reference count
+ // will be decremented.
+ deprecatedEndpoint = ep
+ }
+ }
+
+ return deprecatedEndpoint
+}
+
+// AcquireAssignedAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if addrState, ok := a.mu.endpoints[localAddr]; ok {
+ if !addrState.IsAssigned(allowTemp) {
+ return nil
+ }
+
+ if !addrState.IncRef() {
+ panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr))
+ }
+
+ return addrState
+ }
+
+ if !allowTemp {
+ return nil
+ }
+
+ addr := localAddr.WithPrefix()
+ ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ if err != nil {
+ // addAndAcquireAddressLocked only returns an error if the address is
+ // already assigned but we just checked above if the address exists so we
+ // expect no error.
+ panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err))
+ }
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since addAndAcquireAddressLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
+ return nil
+ }
+ return ep
+}
+
+// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
+ return ep.IsAssigned(allowExpired)
+ })
+
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since acquirePrimaryAddressRLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
+ return nil
+ }
+
+ return ep
+}
+
+// PrimaryAddresses implements AddressableEndpoint.
+func (a *AddressableEndpointState) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ var addrs []tcpip.AddressWithPrefix
+ for _, ep := range a.mu.primary {
+ // Don't include tentative, expired or temporary endpoints
+ // to avoid confusion and prevent the caller from using
+ // those.
+ switch ep.GetKind() {
+ case PermanentTentative, PermanentExpired, Temporary:
+ continue
+ }
+
+ addrs = append(addrs, ep.AddressWithPrefix())
+ }
+
+ return addrs
+}
+
+// PermanentAddresses implements AddressableEndpoint.
+func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefix {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ var addrs []tcpip.AddressWithPrefix
+ for _, ep := range a.mu.endpoints {
+ if !ep.GetKind().IsPermanent() {
+ continue
+ }
+
+ addrs = append(addrs, ep.AddressWithPrefix())
+ }
+
+ return addrs
+}
+
+// JoinGroup implements GroupAddressableEndpoint.
+func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ joins, ok := a.mu.groups[group]
+ if !ok {
+ ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */)
+ if err != nil {
+ return false, err
+ }
+ // We have no need for the address endpoint.
+ a.decAddressRefLocked(ep)
+ }
+
+ a.mu.groups[group] = joins + 1
+ return !ok, nil
+}
+
+// LeaveGroup implements GroupAddressableEndpoint.
+func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ joins, ok := a.mu.groups[group]
+ if !ok {
+ return false, tcpip.ErrBadLocalAddress
+ }
+
+ if joins == 1 {
+ a.removeGroupAddressLocked(group)
+ delete(a.mu.groups, group)
+ return true, nil
+ }
+
+ a.mu.groups[group] = joins - 1
+ return false, nil
+}
+
+// IsInGroup implements GroupAddressableEndpoint.
+func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ _, ok := a.mu.groups[group]
+ return ok
+}
+
+func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) {
+ if err := a.removePermanentAddressLocked(group); err != nil {
+ // removePermanentEndpointLocked would only return an error if group is
+ // not bound to the addressable endpoint, but we know it MUST be assigned
+ // since we have group in our map of groups.
+ panic(fmt.Sprintf("error removing group address = %s: %s", group, err))
+ }
+}
+
+// Cleanup forcefully leaves all groups and removes all permanent addresses.
+func (a *AddressableEndpointState) Cleanup() {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ for group := range a.mu.groups {
+ a.removeGroupAddressLocked(group)
+ }
+ a.mu.groups = make(map[tcpip.Address]uint32)
+
+ for _, ep := range a.mu.endpoints {
+ // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
+ // not a permanent address.
+ if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err))
+ }
+ }
+}
+
+var _ AddressEndpoint = (*addressState)(nil)
+
+// addressState holds state for an address.
+type addressState struct {
+ addressableEndpointState *AddressableEndpointState
+ addr tcpip.AddressWithPrefix
+
+ // Lock ordering (from outer to inner lock ordering):
+ //
+ // AddressableEndpointState.mu
+ // addressState.mu
+ mu struct {
+ sync.RWMutex
+
+ refs uint32
+ kind AddressKind
+ configType AddressConfigType
+ deprecated bool
+ }
+}
+
+// AddressWithPrefix implements AddressEndpoint.
+func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
+ return a.addr
+}
+
+// GetKind implements AddressEndpoint.
+func (a *addressState) GetKind() AddressKind {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ return a.mu.kind
+}
+
+// SetKind implements AddressEndpoint.
+func (a *addressState) SetKind(kind AddressKind) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.mu.kind = kind
+}
+
+// IsAssigned implements AddressEndpoint.
+func (a *addressState) IsAssigned(allowExpired bool) bool {
+ if !a.addressableEndpointState.networkEndpoint.Enabled() {
+ return false
+ }
+
+ switch a.GetKind() {
+ case PermanentTentative:
+ return false
+ case PermanentExpired:
+ return allowExpired
+ default:
+ return true
+ }
+}
+
+// IncRef implements AddressEndpoint.
+func (a *addressState) IncRef() bool {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ if a.mu.refs == 0 {
+ return false
+ }
+
+ a.mu.refs++
+ return true
+}
+
+// DecRef implements AddressEndpoint.
+func (a *addressState) DecRef() {
+ a.addressableEndpointState.decAddressRef(a)
+}
+
+// ConfigType implements AddressEndpoint.
+func (a *addressState) ConfigType() AddressConfigType {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ return a.mu.configType
+}
+
+// SetDeprecated implements AddressEndpoint.
+func (a *addressState) SetDeprecated(d bool) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.mu.deprecated = d
+}
+
+// Deprecated implements AddressEndpoint.
+func (a *addressState) Deprecated() bool {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ return a.mu.deprecated
+}
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
new file mode 100644
index 000000000..26787d0a3
--- /dev/null
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -0,0 +1,77 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// TestAddressableEndpointStateCleanup tests that cleaning up an addressable
+// endpoint state removes permanent addresses and leaves groups.
+func TestAddressableEndpointStateCleanup(t *testing.T) {
+ var ep fakeNetworkEndpoint
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
+ var s stack.AddressableEndpointState
+ s.Init(&ep)
+
+ addr := tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: 8,
+ }
+
+ {
+ ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ if err != nil {
+ t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err)
+ }
+ // We don't need the address endpoint.
+ ep.DecRef()
+ }
+ {
+ ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
+ if ep == nil {
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address)
+ }
+ ep.DecRef()
+ }
+
+ group := tcpip.Address("\x02")
+ if added, err := s.JoinGroup(group); err != nil {
+ t.Fatalf("s.JoinGroup(%s): %s", group, err)
+ } else if !added {
+ t.Fatalf("got s.JoinGroup(%s) = false, want = true", group)
+ }
+ if !s.IsInGroup(group) {
+ t.Fatalf("got s.IsInGroup(%s) = false, want = true", group)
+ }
+
+ s.Cleanup()
+ {
+ ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
+ if ep != nil {
+ ep.DecRef()
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
+ }
+ }
+ if s.IsInGroup(group) {
+ t.Fatalf("got s.IsInGroup(%s) = true, want = false", group)
+ }
+}
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index af9c325ca..0cd1da11f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -15,9 +15,12 @@
package stack
import (
+ "encoding/binary"
"sync"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
)
@@ -30,6 +33,10 @@ import (
//
// Currently, only TCP tracking is supported.
+// Our hash table has 16K buckets.
+// TODO(gvisor.dev/issue/170): These should be tunable.
+const numBuckets = 1 << 14
+
// Direction of the tuple.
type direction int
@@ -42,13 +49,19 @@ const (
type manipType int
const (
- manipDstPrerouting manipType = iota
+ manipNone manipType = iota
+ manipDstPrerouting
manipDstOutput
)
// tuple holds a connection's identifying and manipulating data in one
// direction. It is immutable.
+//
+// +stateify savable
type tuple struct {
+ // tupleEntry is used to build an intrusive list of tuples.
+ tupleEntry
+
tupleID
// conn is the connection tracking entry this tuple belongs to.
@@ -61,6 +74,8 @@ type tuple struct {
// tupleID uniquely identifies a connection in one direction. It currently
// contains enough information to distinguish between any TCP or UDP
// connection, and will need to be extended to support other protocols.
+//
+// +stateify savable
type tupleID struct {
srcAddr tcpip.Address
srcPort uint16
@@ -83,6 +98,8 @@ func (ti tupleID) reply() tupleID {
}
// conn is a tracked connection.
+//
+// +stateify savable
type conn struct {
// original is the tuple in original direction. It is immutable.
original tuple
@@ -97,36 +114,98 @@ type conn struct {
// update the state of tcb. It is immutable.
tcbHook Hook
- // mu protects tcb.
- mu sync.Mutex
-
+ // mu protects all mutable state.
+ mu sync.Mutex `state:"nosave"`
// tcb is TCB control block. It is used to keep track of states
// of tcp connection and is protected by mu.
tcb tcpconntrack.TCB
+ // lastUsed is the last time the connection saw a relevant packet, and
+ // is updated by each packet on the connection. It is protected by mu.
+ lastUsed time.Time `state:".(unixTime)"`
+}
+
+// timedOut returns whether the connection timed out based on its state.
+func (cn *conn) timedOut(now time.Time) bool {
+ const establishedTimeout = 5 * 24 * time.Hour
+ const defaultTimeout = 120 * time.Second
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+ if cn.tcb.State() == tcpconntrack.ResultAlive {
+ // Use the same default as Linux, which doesn't delete
+ // established connections for 5(!) days.
+ return now.Sub(cn.lastUsed) > establishedTimeout
+ }
+ // Use the same default as Linux, which lets connections in most states
+ // other than established remain for <= 120 seconds.
+ return now.Sub(cn.lastUsed) > defaultTimeout
+}
+
+// update the connection tracking state.
+//
+// Precondition: ct.mu must be held.
+func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+ // Update the state of tcb. tcb assumes it's always initialized on the
+ // client. However, we only need to know whether the connection is
+ // established or not, so the client/server distinction isn't important.
+ // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
+ // other tcp states.
+ if ct.tcb.IsEmpty() {
+ ct.tcb.Init(tcpHeader)
+ } else if hook == ct.tcbHook {
+ ct.tcb.UpdateStateOutbound(tcpHeader)
+ } else {
+ ct.tcb.UpdateStateInbound(tcpHeader)
+ }
}
// ConnTrack tracks all connections created for NAT rules. Most users are
-// expected to only call handlePacket and createConnFor.
+// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
+//
+// ConnTrack keeps all connections in a slice of buckets, each of which holds a
+// linked list of tuples. This gives us some desirable properties:
+// - Each bucket has its own lock, lessening lock contention.
+// - The slice is large enough that lists stay short (<10 elements on average).
+// Thus traversal is fast.
+// - During linked list traversal we reap expired connections. This amortizes
+// the cost of reaping them and makes reapUnused faster.
+//
+// Locks are ordered by their location in the buckets slice. That is, a
+// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j.
+//
+// +stateify savable
type ConnTrack struct {
- // mu protects conns.
- mu sync.RWMutex
+ // seed is a one-time random value initialized at stack startup
+ // and is used in the calculation of hash keys for the list of buckets.
+ // It is immutable.
+ seed uint32
- // conns maintains a map of tuples needed for connection tracking for
- // iptables NAT rules. It is protected by mu.
- conns map[tupleID]tuple
+ // mu protects the buckets slice, but not buckets' contents. Only take
+ // the write lock if you are modifying the slice or saving for S/R.
+ mu sync.RWMutex `state:"nosave"`
+
+ // buckets is protected by mu.
+ buckets []bucket
+}
+
+// +stateify savable
+type bucket struct {
+ // mu protects tuples.
+ mu sync.Mutex `state:"nosave"`
+ tuples tupleList
}
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
// TCP header.
+//
+// Preconditions: pkt.NetworkHeader() is valid.
func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- netHeader := header.IPv4(pkt.NetworkHeader)
- if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ netHeader := pkt.Network()
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
return tupleID{}, tcpip.ErrUnknownProtocol
}
- tcpHeader := header.TCP(pkt.TransportHeader)
- if tcpHeader == nil {
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+ if len(tcpHeader) < header.TCPMinimumSize {
return tupleID{}, tcpip.ErrUnknownProtocol
}
@@ -136,15 +215,16 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
dstAddr: netHeader.DestinationAddress(),
dstPort: tcpHeader.DestinationPort(),
transProto: netHeader.TransportProtocol(),
- netProto: header.IPv4ProtocolNumber,
+ netProto: pkt.NetworkProtocolNumber,
}, nil
}
// newConn creates new connection.
func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
conn := conn{
- manip: manip,
- tcbHook: hook,
+ manip: manip,
+ tcbHook: hook,
+ lastUsed: time.Now(),
}
conn.original = tuple{conn: &conn, tupleID: orig}
conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
@@ -161,19 +241,35 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
if err != nil {
return nil, dirOriginal
}
+ return ct.connForTID(tid)
+}
- ct.mu.Lock()
- defer ct.mu.Unlock()
-
- tuple, ok := ct.conns[tid]
- if !ok {
- return nil, dirOriginal
+func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
+ bucket := ct.bucket(tid)
+ now := time.Now()
+
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ ct.buckets[bucket].mu.Lock()
+ defer ct.buckets[bucket].mu.Unlock()
+
+ // Iterate over the tuples in a bucket, cleaning up any unused
+ // connections we find.
+ for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
+ // Clean up any timed-out connections we happen to find.
+ if ct.reapTupleLocked(other, bucket, now) {
+ // The tuple expired.
+ continue
+ }
+ if tid == other.tupleID {
+ return other.conn, other.direction
+ }
}
- return tuple.conn, tuple.direction
+
+ return nil, dirOriginal
}
-// createConnFor creates a new conn for pkt.
-func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -186,8 +282,8 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg
// rule. This tuple will be used to manipulate the packet in
// handlePacket.
replyTID := tid.reply()
- replyTID.srcAddr = rt.MinIP
- replyTID.srcPort = rt.MinPort
+ replyTID.srcAddr = rt.Addr
+ replyTID.srcPort = rt.Port
var manip manipType
switch hook {
case Prerouting:
@@ -196,23 +292,61 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg
manip = manipDstOutput
}
conn := newConn(tid, replyTID, manip, hook)
+ ct.insertConn(conn)
+ return conn
+}
- // Add the changed tuple to the map.
- // TODO(gvisor.dev/issue/170): Need to support collisions using linked
- // list.
- ct.mu.Lock()
- defer ct.mu.Unlock()
- ct.conns[tid] = conn.original
- ct.conns[replyTID] = conn.reply
+// insertConn inserts conn into the appropriate table bucket.
+func (ct *ConnTrack) insertConn(conn *conn) {
+ // Lock the buckets in the correct order.
+ tupleBucket := ct.bucket(conn.original.tupleID)
+ replyBucket := ct.bucket(conn.reply.tupleID)
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ if tupleBucket < replyBucket {
+ ct.buckets[tupleBucket].mu.Lock()
+ ct.buckets[replyBucket].mu.Lock()
+ } else if tupleBucket > replyBucket {
+ ct.buckets[replyBucket].mu.Lock()
+ ct.buckets[tupleBucket].mu.Lock()
+ } else {
+ // Both tuples are in the same bucket.
+ ct.buckets[tupleBucket].mu.Lock()
+ }
- return conn
+ // Now that we hold the locks, ensure the tuple hasn't been inserted by
+ // another thread.
+ alreadyInserted := false
+ for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
+ if other.tupleID == conn.original.tupleID {
+ alreadyInserted = true
+ break
+ }
+ }
+
+ if !alreadyInserted {
+ // Add the tuple to the map.
+ ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
+ ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ }
+
+ // Unlocking can happen in any order.
+ ct.buckets[tupleBucket].mu.Unlock()
+ if tupleBucket != replyBucket {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
}
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
- netHeader := header.IPv4(pkt.NetworkHeader)
- tcpHeader := header.TCP(pkt.TransportHeader)
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
+ netHeader := pkt.Network()
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
// For prerouting redirection, packets going in the original direction
// have their destinations modified and replies have their sources
@@ -228,14 +362,28 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
netHeader.SetSourceAddress(conn.original.dstAddr)
}
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
+ // on inbound packets, so we don't recalculate them. However, we should
+ // support cases when they are validated, e.g. when we can't offload
+ // receive checksumming.
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacketOutput manipulates ports for packets in Output hook.
func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
- netHeader := header.IPv4(pkt.NetworkHeader)
- tcpHeader := header.TCP(pkt.TransportHeader)
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
+ netHeader := pkt.Network()
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
// For output redirection, packets going in the original direction
// have their destinations modified and replies have their sources
@@ -253,8 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- hdr := &pkt.Header
- length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength())
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
@@ -263,25 +410,39 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacket will manipulate the port and address of the packet if the
-// connection exists.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+// connection exists. Returns whether, after the packet traverses the tables,
+// it should create a new entry in the table.
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
if pkt.NatDone {
- return
+ return false
}
if hook != Prerouting && hook != Output {
- return
+ return false
+ }
+
+ // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ return false
}
conn, dir := ct.connFor(pkt)
+ // Connection or Rule not found for the packet.
if conn == nil {
- // Connection not found for the packet or the packet is invalid.
- return
+ return true
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+ if len(tcpHeader) < header.TCPMinimumSize {
+ return false
}
switch hook {
@@ -297,35 +458,184 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
// other tcp states.
conn.mu.Lock()
defer conn.mu.Unlock()
- var st tcpconntrack.Result
- tcpHeader := header.TCP(pkt.TransportHeader)
- if conn.tcb.IsEmpty() {
- conn.tcb.Init(tcpHeader)
- conn.tcbHook = hook
- } else {
- switch hook {
- case conn.tcbHook:
- st = conn.tcb.UpdateStateOutbound(tcpHeader)
- default:
- st = conn.tcb.UpdateStateInbound(tcpHeader)
+
+ // Mark the connection as having been used recently so it isn't reaped.
+ conn.lastUsed = time.Now()
+ // Update connection state.
+ conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+
+ return false
+}
+
+// maybeInsertNoop tries to insert a no-op connection entry to keep connections
+// from getting clobbered when replies arrive. It only inserts if there isn't
+// already a connection for pkt.
+//
+// This should be called after traversing iptables rules only, to ensure that
+// pkt.NatDone is set correctly.
+func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
+ // If there were a rule applying to this packet, it would be marked
+ // with NatDone.
+ if pkt.NatDone {
+ return
+ }
+
+ // We only track TCP connections.
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ return
+ }
+
+ // This is the first packet we're seeing for the TCP connection. Insert
+ // the noop entry (an identity mapping) so that the response doesn't
+ // get NATed, breaking the connection.
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return
+ }
+ conn := newConn(tid, tid.reply(), manipNone, hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+ ct.insertConn(conn)
+}
+
+// bucket gets the conntrack bucket for a tupleID.
+func (ct *ConnTrack) bucket(id tupleID) int {
+ h := jenkins.Sum32(ct.seed)
+ h.Write([]byte(id.srcAddr))
+ h.Write([]byte(id.dstAddr))
+ shortBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(shortBuf, id.srcPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, id.dstPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto))
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto))
+ h.Write([]byte(shortBuf))
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ return int(h.Sum32()) % len(ct.buckets)
+}
+
+// reapUnused deletes timed out entries from the conntrack map. The rules for
+// reaping are:
+// - Most reaping occurs in connFor, which is called on each packet. connFor
+// cleans up the bucket the packet's connection maps to. Thus calls to
+// reapUnused should be fast.
+// - Each call to reapUnused traverses a fraction of the conntrack table.
+// Specifically, it traverses len(ct.buckets)/fractionPerReaping.
+// - After reaping, reapUnused decides when it should next run based on the
+// ratio of expired connections to examined connections. If the ratio is
+// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it
+// slightly increases the interval between runs.
+// - maxFullTraversal caps the time it takes to traverse the entire table.
+//
+// reapUnused returns the next bucket that should be checked and the time after
+// which it should be called again.
+func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
+ // TODO(gvisor.dev/issue/170): This can be more finely controlled, as
+ // it is in Linux via sysctl.
+ const fractionPerReaping = 128
+ const maxExpiredPct = 50
+ const maxFullTraversal = 60 * time.Second
+ const minInterval = 10 * time.Millisecond
+ const maxInterval = maxFullTraversal / fractionPerReaping
+
+ now := time.Now()
+ checked := 0
+ expired := 0
+ var idx int
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
+ idx = (i + start) % len(ct.buckets)
+ ct.buckets[idx].mu.Lock()
+ for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ checked++
+ if ct.reapTupleLocked(tuple, idx, now) {
+ expired++
+ }
}
+ ct.buckets[idx].mu.Unlock()
+ }
+ // We already checked buckets[idx].
+ idx++
+
+ // If half or more of the connections are expired, the table has gotten
+ // stale. Reschedule quickly.
+ expiredPct := 0
+ if checked != 0 {
+ expiredPct = expired * 100 / checked
+ }
+ if expiredPct > maxExpiredPct {
+ return idx, minInterval
+ }
+ if interval := prevInterval + minInterval; interval <= maxInterval {
+ // Increment the interval between runs.
+ return idx, interval
+ }
+ // We've hit the maximum interval.
+ return idx, maxInterval
+}
+
+// reapTupleLocked tries to remove tuple and its reply from the table. It
+// returns whether the tuple's connection has timed out.
+//
+// Preconditions:
+// * ct.mu is locked for reading.
+// * bucket is locked.
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+ if !tuple.conn.timedOut(now) {
+ return false
+ }
+
+ // To maintain lock order, we can only reap these tuples if the reply
+ // appears later in the table.
+ replyBucket := ct.bucket(tuple.reply())
+ if bucket > replyBucket {
+ return true
+ }
+
+ // Don't re-lock if both tuples are in the same bucket.
+ differentBuckets := bucket != replyBucket
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Lock()
+ }
+
+ // We have the buckets locked and can remove both tuples.
+ if tuple.direction == dirOriginal {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
+ } else {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
}
+ ct.buckets[bucket].tuples.Remove(tuple)
- // Delete conn if tcp connection is closed.
- if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset {
- ct.deleteConn(conn)
+ // Don't re-unlock if both tuples are in the same bucket.
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Unlock()
}
+
+ return true
}
-// deleteConn deletes the connection.
-func (ct *ConnTrack) deleteConn(conn *conn) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
+ // Lookup the connection. The reply's original destination
+ // describes the original address.
+ tid := tupleID{
+ srcAddr: epID.LocalAddress,
+ srcPort: epID.LocalPort,
+ dstAddr: epID.RemoteAddress,
+ dstPort: epID.RemotePort,
+ transProto: header.TCPProtocolNumber,
+ netProto: netProto,
+ }
+ conn, _ := ct.connForTID(tid)
if conn == nil {
- return
+ // Not a tracked connection.
+ return "", 0, tcpip.ErrNotConnected
+ } else if conn.manip == manipNone {
+ // Unmanipulated connection.
+ return "", 0, tcpip.ErrInvalidOptionValue
}
- ct.mu.Lock()
- defer ct.mu.Unlock()
-
- delete(ct.conns, conn.original.tupleID)
- delete(ct.conns, conn.reply.tupleID)
+ return conn.original.dstAddr, conn.original.dstPort, nil
}
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index a6546cef0..4e4b00a92 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -20,8 +20,10 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
@@ -44,37 +46,37 @@ const (
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fwdTestNetworkEndpoint struct {
+ AddressableEndpointState
+
nicID tcpip.NICID
- id NetworkEndpointID
- prefixLen int
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
ep LinkEndpoint
}
-func (f *fwdTestNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
+
+func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error {
+ return nil
}
-func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicID
+func (*fwdTestNetworkEndpoint) Enabled() bool {
+ return true
}
-func (f *fwdTestNetworkEndpoint) PrefixLen() int {
- return f.prefixLen
+func (*fwdTestNetworkEndpoint) Disable() {}
+
+func (f *fwdTestNetworkEndpoint) MTU() uint32 {
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
- return &f.id
-}
-
func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -85,10 +87,6 @@ func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportPr
return 0
}
-func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities {
- return f.ep.Capabilities()
-}
-
func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return f.proto.Number()
}
@@ -96,9 +94,9 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu
func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
// Add the protocol's header to the packet and send it to the link
// endpoint.
- b := pkt.Header.Prepend(fwdTestNetHeaderLen)
+ b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen)
b[dstAddrOffset] = r.RemoteAddress[0]
- b[srcAddrOffset] = f.id.LocalAddress[0]
+ b[srcAddrOffset] = r.LocalAddress[0]
b[protocolNumberOffset] = byte(params.Protocol)
return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
@@ -113,17 +111,28 @@ func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBu
return tcpip.ErrNotSupported
}
-func (*fwdTestNetworkEndpoint) Close() {}
+func (f *fwdTestNetworkEndpoint) Close() {
+ f.AddressableEndpointState.Cleanup()
+}
// fwdTestNetworkProtocol is a network-layer protocol that implements Address
// resolution.
type fwdTestNetworkProtocol struct {
addrCache *linkAddrCache
+ neigh *neighborCache
addrResolveDelay time.Duration
- onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address)
+ onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
+
+ mu struct {
+ sync.RWMutex
+ forwarding bool
+ }
}
+var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
+var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
+
func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
@@ -141,42 +150,40 @@ func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Add
}
func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
- netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
+ netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen)
if !ok {
return 0, false, false
}
- pkt.NetworkHeader = netHeader
- pkt.Data.TrimFront(fwdTestNetHeaderLen)
- return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true
+ return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
}
-func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
- return &fwdTestNetworkEndpoint{
- nicID: nicID,
- id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
+func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint {
+ e := &fwdTestNetworkEndpoint{
+ nicID: nic.ID(),
proto: f,
dispatcher: dispatcher,
- ep: ep,
- }, nil
+ ep: nic.LinkEndpoint(),
+ }
+ e.AddressableEndpointState.Init(e)
+ return e
}
-func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error {
+func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func (f *fwdTestNetworkProtocol) Close() {}
+func (*fwdTestNetworkProtocol) Close() {}
-func (f *fwdTestNetworkProtocol) Wait() {}
+func (*fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error {
- if f.addrCache != nil && f.onLinkAddressResolved != nil {
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+ if f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
- f.onLinkAddressResolved(f.addrCache, addr)
+ f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr)
})
}
return nil
@@ -189,10 +196,25 @@ func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip
return "", false
}
-func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (f *fwdTestNetworkProtocol) Forwarding() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.forwarding
+
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (f *fwdTestNetworkProtocol) SetForwarding(v bool) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.forwarding = v
+}
+
// fwdTestPacketInfo holds all the information about an outbound packet.
type fwdTestPacketInfo struct {
RemoteLinkAddress tcpip.LinkAddress
@@ -287,7 +309,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
p := fwdTestPacketInfo{
- Pkt: &PacketBuffer{Data: vv},
+ Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}),
}
select {
@@ -301,16 +323,29 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
-func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
+func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
- NetworkProtocols: []NetworkProtocol{proto},
+ NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }},
+ UseNeighborCache: useNeighborCache,
})
- proto.addrCache = s.linkAddrCache
+ if !useNeighborCache {
+ proto.addrCache = s.linkAddrCache
+ }
// Enable forwarding.
- s.SetForwarding(true)
+ s.SetForwarding(proto.Number(), true)
// NIC 1 has the link address "a", and added the network address 1.
ep1 = &fwdTestLinkEndpoint{
@@ -338,6 +373,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
t.Fatal("AddAddress #2 failed:", err)
}
+ if useNeighborCache {
+ // Control the neighbor cache for NIC 2.
+ nic, ok := s.nics[2]
+ if !ok {
+ t.Fatal("failed to get the neighbor cache for NIC 2")
+ }
+ proto.neigh = nic.neigh
+ }
+
// Route all packets to NIC 2.
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -351,79 +395,129 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
}
func TestForwardingWithStaticResolver(t *testing.T) {
- // Create a network protocol with a static resolver.
- proto := &fwdTestNetworkProtocol{
- onResolveStaticAddress:
- // The network address 3 is resolved to the link address "c".
- func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
- if addr == "\x03" {
- return "c", true
- }
- return "", false
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Create a network protocol with a static resolver.
+ proto := &fwdTestNetworkProtocol{
+ onResolveStaticAddress:
+ // The network address 3 is resolved to the link address "c".
+ func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "\x03" {
+ return "c", true
+ }
+ return "", false
+ },
+ }
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache)
- var p fwdTestPacketInfo
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
- select {
- case p = <-ep2.C:
- default:
- t.Fatal("packet not forwarded")
- }
+ var p fwdTestPacketInfo
- // Test that the static address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ select {
+ case p = <-ep2.C:
+ default:
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the static address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ })
}
}
func TestForwardingWithFakeResolver(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
- // Any address will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any address will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Any address will be resolved to the link address "c".
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
- var p fwdTestPacketInfo
+ var p fwdTestPacketInfo
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ })
}
}
@@ -431,15 +525,17 @@ func TestForwardingWithNoResolver(t *testing.T) {
// Create a network protocol without a resolver.
proto := &fwdTestNetworkProtocol{}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ // Whether or not we use the neighbor cache here does not matter since
+ // neither linkAddrCache nor neighborCache will be used.
+ ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */)
// inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
select {
case <-ep2.C:
@@ -449,202 +545,334 @@ func TestForwardingWithNoResolver(t *testing.T) {
}
func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
- // Only packets to address 3 will be resolved to the
- // link address "c".
- if addr == "\x03" {
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
- }
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Only packets to address 3 will be resolved to the
+ // link address "c".
+ if addr == "\x03" {
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ }
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Only packets to address 3 will be resolved to the
+ // link address "c".
+ if addr == "\x03" {
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ }
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
-
- // Inject an inbound packet to address 4 on NIC 1. This packet should
- // not be forwarded.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 4
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
-
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf = buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
-
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
+
+ // Inject an inbound packet to address 4 on NIC 1. This packet should
+ // not be forwarded.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 4
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf = buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
- if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
- }
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ })
}
}
func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
- // Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Any packets will be resolved to the link address "c".
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
-
- // Inject two inbound packets to address 3 on NIC 1.
- for i := 0; i < 2; i++ {
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- }
-
- for i := 0; i < 2; i++ {
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
- if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
- }
+ // Inject two inbound packets to address 3 on NIC 1.
+ for i := 0; i < 2; i++ {
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
+ for i := 0; i < 2; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+ })
}
}
func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
- // Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Any packets will be resolved to the link address "c".
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
-
- for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
- // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- // Set the packet sequence number.
- binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- }
-
- for i := 0; i < maxPendingPacketsPerResolution; i++ {
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
-
- if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 {
- t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
- }
- seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes).
- if !ok {
- t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size())
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
- // The first 5 packets should not be forwarded so the sequence number should
- // start with 5.
- want := uint16(i + 5)
- if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
- t.Fatalf("got the packet #%d, want = #%d", n, want)
- }
+ for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
+ // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ // Set the packet sequence number.
+ binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
+ for i := 0; i < maxPendingPacketsPerResolution; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ b := PayloadSince(p.Pkt.NetworkHeader())
+ if b[dstAddrOffset] != 3 {
+ t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
+ }
+ if len(b) < fwdTestNetHeaderLen+2 {
+ t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b)
+ }
+ seqNumBuf := b[fwdTestNetHeaderLen:]
+
+ // The first 5 packets should not be forwarded so the sequence number should
+ // start with 5.
+ want := uint16(i + 5)
+ if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
+ t.Fatalf("got the packet #%d, want = #%d", n, want)
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+ })
}
}
func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
- // Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Any packets will be resolved to the link address "c".
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
-
- for i := 0; i < maxPendingResolutions+5; i++ {
- // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
- // Each packet has a different destination address (3 to
- // maxPendingResolutions + 7).
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = byte(3 + i)
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- }
-
- for i := 0; i < maxPendingResolutions; i++ {
- var p fwdTestPacketInfo
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
-
- // The first 5 packets (address 3 to 7) should not be forwarded
- // because their address resolutions are interrupted.
- if p.Pkt.NetworkHeader[dstAddrOffset] < 8 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset])
- }
+ for i := 0; i < maxPendingResolutions+5; i++ {
+ // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
+ // Each packet has a different destination address (3 to
+ // maxPendingResolutions + 7).
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = byte(3 + i)
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
+ for i := 0; i < maxPendingResolutions; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // The first 5 packets (address 3 to 7) should not be forwarded
+ // because their address resolutions are interrupted.
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/headertype_string.go b/pkg/tcpip/stack/headertype_string.go
new file mode 100644
index 000000000..5efddfaaf
--- /dev/null
+++ b/pkg/tcpip/stack/headertype_string.go
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type headerType ."; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[linkHeader-0]
+ _ = x[networkHeader-1]
+ _ = x[transportHeader-2]
+ _ = x[numHeaderType-3]
+}
+
+const _headerType_name = "linkHeadernetworkHeadertransportHeadernumHeaderType"
+
+var _headerType_index = [...]uint8{0, 10, 23, 38, 51}
+
+func (i headerType) String() string {
+ if i < 0 || i >= headerType(len(_headerType_index)-1) {
+ return "headerType(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _headerType_name[_headerType_index[i]:_headerType_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 974d77c36..8d6d9a7f1 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -16,104 +16,186 @@ package stack
import (
"fmt"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// Table names.
+// tableID is an index into IPTables.tables.
+type tableID int
+
const (
- TablenameNat = "nat"
- TablenameMangle = "mangle"
- TablenameFilter = "filter"
+ natID tableID = iota
+ mangleID
+ filterID
+ numTables
)
-// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
+// Table names.
const (
- ChainNamePrerouting = "PREROUTING"
- ChainNameInput = "INPUT"
- ChainNameForward = "FORWARD"
- ChainNameOutput = "OUTPUT"
- ChainNamePostrouting = "POSTROUTING"
+ NATTable = "nat"
+ MangleTable = "mangle"
+ FilterTable = "filter"
)
+// nameToID is immutable.
+var nameToID = map[string]tableID{
+ NATTable: natID,
+ MangleTable: mangleID,
+ FilterTable: filterID,
+}
+
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
+// reaperDelay is how long to wait before starting to reap connections.
+const reaperDelay = 5 * time.Second
+
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
func DefaultTables() *IPTables {
- // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
- // iotas.
return &IPTables{
- tables: map[string]Table{
- TablenameNat: Table{
+ v4Tables: [numTables]Table{
+ natID: Table{
Rules: []Rule{
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: ErrorTarget{}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- Underflows: map[Hook]int{
+ Underflows: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- UserChains: map[string]int{},
},
- TablenameMangle: Table{
+ mangleID: Table{
Rules: []Rule{
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: ErrorTarget{}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
- Underflows: map[Hook]int{
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
+ },
+ },
+ filterID: Table{
+ Rules: []Rule{
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
+ },
+ },
+ },
+ v6Tables: [numTables]Table{
+ natID: Table{
+ Rules: []Rule{
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: 0,
+ Input: 1,
+ Forward: HookUnset,
+ Output: 2,
+ Postrouting: 3,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: 1,
+ Forward: HookUnset,
+ Output: 2,
+ Postrouting: 3,
+ },
+ },
+ mangleID: Table{
+ Rules: []Rule{
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ },
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
- UserChains: map[string]int{},
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
+ },
},
- TablenameFilter: Table{
+ filterID: Table{
Rules: []Rule{
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: AcceptTarget{}},
- Rule{Target: ErrorTarget{}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
- BuiltinChains: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
},
- priorities: map[Hook][]string{
- Input: []string{TablenameNat, TablenameFilter},
- Prerouting: []string{TablenameMangle, TablenameNat},
- Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
+ priorities: [NumHooks][]tableID{
+ Prerouting: []tableID{mangleID, natID},
+ Input: []tableID{natID, filterID},
+ Output: []tableID{mangleID, natID, filterID},
},
connections: ConnTrack{
- conns: make(map[tupleID]tuple),
+ seed: generateRandUint32(),
},
+ reaperDone: make(chan struct{}, 1),
}
}
@@ -122,62 +204,66 @@ func DefaultTables() *IPTables {
func EmptyFilterTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// EmptyNatTable returns a Table with no rules and the filter table chains
+// EmptyNATTable returns a Table with no rules and the filter table chains
// mapped to HookUnset.
-func EmptyNatTable() Table {
+func EmptyNATTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Forward: HookUnset,
},
- Underflows: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ Underflows: [NumHooks]int{
+ Forward: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// GetTable returns table by name.
-func (it *IPTables) GetTable(name string) (Table, bool) {
+// GetTable returns a table by name.
+func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) {
+ id, ok := nameToID[name]
+ if !ok {
+ return Table{}, false
+ }
it.mu.RLock()
defer it.mu.RUnlock()
- t, ok := it.tables[name]
- return t, ok
+ if ipv6 {
+ return it.v6Tables[id], true
+ }
+ return it.v4Tables[id], true
}
// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table) {
+func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error {
+ id, ok := nameToID[name]
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
it.mu.Lock()
defer it.mu.Unlock()
+ // If iptables is being enabled, initialize the conntrack table and
+ // reaper.
+ if !it.modified {
+ it.connections.buckets = make([]bucket, numBuckets)
+ it.startReaper(reaperDelay)
+ }
it.modified = true
- it.tables[name] = table
-}
-
-// GetPriorities returns slice of priorities associated with hook.
-func (it *IPTables) GetPriorities(hook Hook) []string {
- it.mu.RLock()
- defer it.mu.RUnlock()
- return it.priorities[hook]
+ if ipv6 {
+ it.v6Tables[id] = table
+ } else {
+ it.v4Tables[id] = table
+ }
+ return nil
}
// A chainVerdict is what a table decides should be done with a packet.
@@ -199,26 +285,43 @@ const (
// should continue traversing the network stack and false when it should be
// dropped.
//
+// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from
+// which address and nicName can be gathered. Currently, address is only
+// needed for prerouting and nicName is only needed for output.
+//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool {
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool {
+ if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
+ return true
+ }
// Many users never configure iptables. Spare them the cost of rule
// traversal if rules have never been set.
it.mu.RLock()
+ defer it.mu.RUnlock()
if !it.modified {
- it.mu.RUnlock()
return true
}
- it.mu.RUnlock()
// Packets are manipulated only if connection and matching
// NAT rule exists.
- it.connections.handlePacket(pkt, hook, gso, r)
+ shouldTrack := it.connections.handlePacket(pkt, hook, gso, r)
// Go through each table containing the hook.
- for _, tablename := range it.GetPriorities(hook) {
- table, _ := it.GetTable(tablename)
+ priorities := it.priorities[hook]
+ for _, tableID := range priorities {
+ // If handlePacket already NATed the packet, we don't need to
+ // check the NAT table.
+ if tableID == natID && pkt.NatDone {
+ continue
+ }
+ var table Table
+ if pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber {
+ table = it.v6Tables[tableID]
+ } else {
+ table = it.v4Tables[tableID]
+ }
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -229,7 +332,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v {
+ switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v {
case RuleAccept:
continue
case RuleDrop:
@@ -245,17 +348,59 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
}
}
+ // If this connection should be tracked, try to add an entry for it. If
+ // traversing the nat table didn't end in adding an entry,
+ // maybeInsertNoop will add a no-op entry for the connection. This is
+ // needeed when establishing connections so that the SYN/ACK reply to an
+ // outgoing SYN is delivered to the correct endpoint rather than being
+ // redirected by a prerouting rule.
+ //
+ // From the iptables documentation: "If there is no rule, a `null'
+ // binding is created: this usually does not map the packet, but exists
+ // to ensure we don't map another stream over an existing one."
+ if shouldTrack {
+ it.connections.maybeInsertNoop(pkt, hook)
+ }
+
// Every table returned Accept.
return true
}
+// beforeSave is invoked by stateify.
+func (it *IPTables) beforeSave() {
+ // Ensure the reaper exits cleanly.
+ it.reaperDone <- struct{}{}
+ // Prevent others from modifying the connection table.
+ it.connections.mu.Lock()
+}
+
+// afterLoad is invoked by stateify.
+func (it *IPTables) afterLoad() {
+ it.startReaper(reaperDelay)
+}
+
+// startReaper starts a goroutine that wakes up periodically to reap timed out
+// connections.
+func (it *IPTables) startReaper(interval time.Duration) {
+ go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved.
+ bucket := 0
+ for {
+ select {
+ case <-it.reaperDone:
+ return
+ case <-time.After(interval):
+ bucket, interval = it.connections.reapUnused(bucket, interval)
+ }
+ }
+ }()
+}
+
// CheckPackets runs pkts through the rules for hook and returns a map of packets that
// should not go forward.
//
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-//
-// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// * pkt.NetworkHeader is not nil.
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
@@ -279,14 +424,14 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
return drop, natPkts
}
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
-// precondition.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict {
+// Preconditions:
+// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// * pkt.NetworkHeader is not nil.
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
case RuleAccept:
return chainAccept
@@ -303,7 +448,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -326,25 +471,14 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
return chainDrop
}
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
-// precondition.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) {
+// Preconditions:
+// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// * pkt.NetworkHeader is not nil.
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
- // If pkt.NetworkHeader hasn't been set yet, it will be contained in
- // pkt.Data.
- if pkt.NetworkHeader == nil {
- var ok bool
- pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- // Precondition has been violated.
- panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize))
- }
- }
-
// Check whether the packet matches the IP header filter.
- if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) {
+ if !rule.Filter.match(pkt, hook, nicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -363,5 +497,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, gso, r, address)
+ return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr)
+}
+
+// OriginalDst returns the original destination of redirected connections. It
+// returns an error if the connection doesn't exist or isn't redirected.
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ if !it.modified {
+ return "", 0, tcpip.ErrNotConnected
+ }
+ return it.connections.originalDst(epID, netProto)
}
diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go
new file mode 100644
index 000000000..529e02a07
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_state.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "time"
+)
+
+// +stateify savable
+type unixTime struct {
+ second int64
+ nano int64
+}
+
+// saveLastUsed is invoked by stateify.
+func (cn *conn) saveLastUsed() unixTime {
+ return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()}
+}
+
+// loadLastUsed is invoked by stateify.
+func (cn *conn) loadLastUsed(unix unixTime) {
+ cn.lastUsed = time.Unix(unix.second, unix.nano)
+}
+
+// beforeSave is invoked by stateify.
+func (ct *ConnTrack) beforeSave() {
+ ct.mu.Lock()
+}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index d43f60c67..538c4625d 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -21,117 +21,178 @@ import (
)
// AcceptTarget accepts packets.
-type AcceptTarget struct{}
+type AcceptTarget struct {
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (at *AcceptTarget) ID() TargetID {
+ return TargetID{
+ NetworkProtocol: at.NetworkProtocol,
+ }
+}
// Action implements Target.Action.
-func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
}
// DropTarget drops packets.
-type DropTarget struct{}
+type DropTarget struct {
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (dt *DropTarget) ID() TargetID {
+ return TargetID{
+ NetworkProtocol: dt.NetworkProtocol,
+ }
+}
// Action implements Target.Action.
-func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
+// ErrorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const ErrorTargetName = "ERROR"
+
// ErrorTarget logs an error and drops the packet. It represents a target that
// should be unreachable.
-type ErrorTarget struct{}
+type ErrorTarget struct {
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (et *ErrorTarget) ID() TargetID {
+ return TargetID{
+ Name: ErrorTargetName,
+ NetworkProtocol: et.NetworkProtocol,
+ }
+}
// Action implements Target.Action.
-func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
// UserChainTarget marks a rule as the beginning of a user chain.
type UserChainTarget struct {
+ // Name is the chain name.
Name string
+
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (uc *UserChainTarget) ID() TargetID {
+ return TargetID{
+ Name: ErrorTargetName,
+ NetworkProtocol: uc.NetworkProtocol,
+ }
}
// Action implements Target.Action.
-func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
// ReturnTarget returns from the current chain. If the chain is a built-in, the
// hook's underflow should be called.
-type ReturnTarget struct{}
+type ReturnTarget struct {
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// ID implements Target.ID.
+func (rt *ReturnTarget) ID() TargetID {
+ return TargetID{
+ NetworkProtocol: rt.NetworkProtocol,
+ }
+}
// Action implements Target.Action.
-func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
+// RedirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port/destination IP for packets.
+const RedirectTargetName = "REDIRECT"
+
// RedirectTarget redirects the packet by modifying the destination port/IP.
-// Min and Max values for IP and Ports in the struct indicate the range of
-// values which can be used to redirect.
+// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
+// them.
type RedirectTarget struct {
- // TODO(gvisor.dev/issue/170): Other flags need to be added after
- // we support them.
- // RangeProtoSpecified flag indicates single port is specified to
- // redirect.
- RangeProtoSpecified bool
+ // Addr indicates address used to redirect.
+ Addr tcpip.Address
- // MinIP indicates address used to redirect.
- MinIP tcpip.Address
+ // Port indicates port used to redirect.
+ Port uint16
- // MaxIP indicates address used to redirect.
- MaxIP tcpip.Address
-
- // MinPort indicates port used to redirect.
- MinPort uint16
+ // NetworkProtocol is the network protocol the target is used with.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
- // MaxPort indicates port used to redirect.
- MaxPort uint16
+// ID implements Target.ID.
+func (rt *RedirectTarget) ID() TargetID {
+ return TargetID{
+ Name: RedirectTargetName,
+ NetworkProtocol: rt.NetworkProtocol,
+ }
}
// Action implements Target.Action.
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
// implementation only works for PREROUTING and calls pkt.Clone(), neither
// of which should be the case.
-func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
}
// Drop the packet if network and transport header are not set.
- if pkt.NetworkHeader == nil || pkt.TransportHeader == nil {
+ if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
return RuleDrop, 0
}
- // Change the address to localhost (127.0.0.1) in Output and
- // to primary address of the incoming interface in Prerouting.
+ // Change the address to localhost (127.0.0.1 or ::1) in Output and to
+ // the primary address of the incoming interface in Prerouting.
switch hook {
case Output:
- rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1})
- rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1})
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ rt.Addr = tcpip.Address([]byte{127, 0, 0, 1})
+ } else {
+ rt.Addr = header.IPv6Loopback
+ }
case Prerouting:
- rt.MinIP = address
- rt.MaxIP = address
+ rt.Addr = address
default:
panic("redirect target is supported only on output and prerouting hooks")
}
// TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
// we need to change dest address (for OUTPUT chain) or ports.
- netHeader := header.IPv4(pkt.NetworkHeader)
- switch protocol := netHeader.TransportProtocol(); protocol {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader)
- udpHeader.SetDestinationPort(rt.MinPort)
+ udpHeader := header.UDP(pkt.TransportHeader().View())
+ udpHeader.SetDestinationPort(rt.Port)
// Calculate UDP checksum and set it.
if hook == Output {
udpHeader.SetChecksum(0)
- hdr := &pkt.Header
- length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength())
// Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(protocol, length)
for _, v := range pkt.Data.Views() {
xsum = header.Checksum(v, xsum)
@@ -140,10 +201,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
- // Change destination address.
- netHeader.SetDestinationAddress(rt.MinIP)
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+
+ pkt.Network().SetDestinationAddress(rt.Addr)
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
@@ -153,7 +219,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
// Set up conection for matching NAT rule. Only the first
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
- if conn := ct.createConnFor(pkt, hook, rt); conn != nil {
+ if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
ct.handlePacket(pkt, hook, gso, r)
}
default:
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index c528ec381..7b3f3e88b 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -15,6 +15,7 @@
package stack
import (
+ "fmt"
"strings"
"sync"
@@ -78,50 +79,66 @@ const (
)
// IPTables holds all the tables for a netstack.
+//
+// +stateify savable
type IPTables struct {
- // mu protects tables, priorities, and modified.
+ // mu protects v4Tables, v6Tables, and modified.
mu sync.RWMutex
-
- // tables maps table names to tables. User tables have arbitrary names.
- // mu needs to be locked for accessing.
- tables map[string]Table
-
- // priorities maps each hook to a list of table names. The order of the
- // list is the order in which each table should be visited for that
- // hook. mu needs to be locked for accessing.
- priorities map[Hook][]string
-
+ // v4Tables and v6tables map tableIDs to tables. They hold builtin
+ // tables only, not user tables. mu must be locked for accessing.
+ v4Tables [numTables]Table
+ v6Tables [numTables]Table
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
// don't utilize iptables.
modified bool
+ // priorities maps each hook to a list of table names. The order of the
+ // list is the order in which each table should be visited for that
+ // hook. It is immutable.
+ priorities [NumHooks][]tableID
+
connections ConnTrack
+
+ // reaperDone can be signaled to stop the reaper goroutine.
+ reaperDone chan struct{}
}
-// A Table defines a set of chains and hooks into the network stack. It is
-// really just a list of rules.
+// A Table defines a set of chains and hooks into the network stack.
+//
+// It is a list of Rules, entry points (BuiltinChains), and error handlers
+// (Underflows). As packets traverse netstack, they hit hooks. When a packet
+// hits a hook, iptables compares it to Rules starting from that hook's entry
+// point. So if a packet hits the Input hook, we look up the corresponding
+// entry point in BuiltinChains and jump to that point.
+//
+// If the Rule doesn't match the packet, iptables continues to the next Rule.
+// If a Rule does match, it can issue a verdict on the packet (e.g. RuleAccept
+// or RuleDrop) that causes the packet to stop traversing iptables. It can also
+// jump to other rules or perform custom actions based on Rule.Target.
+//
+// Underflow Rules are invoked when a chain returns without reaching a verdict.
+//
+// +stateify savable
type Table struct {
// Rules holds the rules that make up the table.
Rules []Rule
// BuiltinChains maps builtin chains to their entrypoint rule in Rules.
- BuiltinChains map[Hook]int
+ BuiltinChains [NumHooks]int
// Underflows maps builtin chains to their underflow rule in Rules
// (i.e. the rule to execute if the chain returns without a verdict).
- Underflows map[Hook]int
-
- // UserChains holds user-defined chains for the keyed by name. Users
- // can give their chains arbitrary names.
- UserChains map[string]int
+ Underflows [NumHooks]int
}
// ValidHooks returns a bitmap of the builtin hooks for the given table.
func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
- for hook := range table.BuiltinChains {
- hooks |= 1 << hook
+ for hook, ruleIdx := range table.BuiltinChains {
+ if ruleIdx != HookUnset {
+ hooks |= 1 << hook
+ }
}
return hooks
}
@@ -130,6 +147,8 @@ func (table *Table) ValidHooks() uint32 {
// contains zero or more matchers, each of which is a specification of which
// packets this rule applies to. If there are no matchers in the rule, it
// applies to any packet.
+//
+// +stateify savable
type Rule struct {
// Filter holds basic IP filtering fields common to every rule.
Filter IPHeaderFilter
@@ -141,11 +160,18 @@ type Rule struct {
Target Target
}
-// IPHeaderFilter holds basic IP filtering data common to every rule.
+// IPHeaderFilter performs basic IP header matching common to every rule.
+//
+// +stateify savable
type IPHeaderFilter struct {
// Protocol matches the transport protocol.
Protocol tcpip.TransportProtocolNumber
+ // CheckProtocol determines whether the Protocol field should be
+ // checked during matching.
+ // TODO(gvisor.dev/issue/3549): Check this field during matching.
+ CheckProtocol bool
+
// Dst matches the destination IP address.
Dst tcpip.Address
@@ -182,16 +208,43 @@ type IPHeaderFilter struct {
OutputInterfaceInvert bool
}
-// match returns whether hdr matches the filter.
-func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool {
- // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+// match returns whether pkt matches the filter.
+//
+// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4
+// or IPv6 header length.
+func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool {
+ // Extract header fields.
+ var (
+ // TODO(gvisor.dev/issue/170): Support other filter fields.
+ transProto tcpip.TransportProtocolNumber
+ dstAddr tcpip.Address
+ srcAddr tcpip.Address
+ )
+ switch proto := pkt.NetworkProtocolNumber; proto {
+ case header.IPv4ProtocolNumber:
+ hdr := header.IPv4(pkt.NetworkHeader().View())
+ transProto = hdr.TransportProtocol()
+ dstAddr = hdr.DestinationAddress()
+ srcAddr = hdr.SourceAddress()
+
+ case header.IPv6ProtocolNumber:
+ hdr := header.IPv6(pkt.NetworkHeader().View())
+ transProto = hdr.TransportProtocol()
+ dstAddr = hdr.DestinationAddress()
+ srcAddr = hdr.SourceAddress()
+
+ default:
+ panic(fmt.Sprintf("unknown network protocol with EtherType: %d", proto))
+ }
+
// Check the transport protocol.
- if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() {
+ if fl.CheckProtocol && fl.Protocol != transProto {
return false
}
- // Check the source and destination IPs.
- if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) {
+ // Check the addresses.
+ if !filterAddress(dstAddr, fl.DstMask, fl.Dst, fl.DstInvert) ||
+ !filterAddress(srcAddr, fl.SrcMask, fl.Src, fl.SrcInvert) {
return false
}
@@ -219,6 +272,18 @@ func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool
return true
}
+// NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header
+// applies.
+func (fl IPHeaderFilter) NetworkProtocol() tcpip.NetworkProtocolNumber {
+ switch len(fl.Src) {
+ case header.IPv4AddressSize:
+ return header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ return header.IPv6ProtocolNumber
+ }
+ panic(fmt.Sprintf("invalid address in IPHeaderFilter: %s", fl.Src))
+}
+
// filterAddress returns whether addr matches the filter.
func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool {
matches := true
@@ -244,8 +309,23 @@ type Matcher interface {
Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
+// A TargetID uniquely identifies a target.
+type TargetID struct {
+ // Name is the target name as stored in the xt_entry_target struct.
+ Name string
+
+ // NetworkProtocol is the protocol to which the target applies.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+
+ // Revision is the version of the target.
+ Revision uint8
+}
+
// A Target is the interface for taking an action for a packet.
type Target interface {
+ // ID uniquely identifies the Target.
+ ID() TargetID
+
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 403557fd7..6f73a0ce4 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+ linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP)
select {
case now := <-time.After(c.resolutionTimeout):
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 1baa498d0..33806340e 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -16,6 +16,7 @@ package stack
import (
"fmt"
+ "math"
"sync/atomic"
"testing"
"time"
@@ -48,7 +49,7 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
if f := r.onLinkAddressRequest; f != nil {
f()
@@ -191,7 +192,13 @@ func TestCacheReplace(t *testing.T) {
}
func TestCacheResolution(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
+ // There is a race condition causing this test to fail when the executor
+ // takes longer than the resolution timeout to call linkAddrCache.get. This
+ // is especially common when this test is run with gotsan.
+ //
+ // Using a large resolution timeout decreases the probability of experiencing
+ // this race condition and does not affect how long this test takes to run.
+ c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1)
linkRes := &testLinkAddressResolver{cache: c}
for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
@@ -275,3 +282,71 @@ func TestStaticResolution(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
}
}
+
+// TestCacheWaker verifies that RemoveWaker removes a waker previously added
+// through get().
+func TestCacheWaker(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+
+ // First, sanity check that wakers are working.
+ {
+ linkRes := &testLinkAddressResolver{cache: c}
+ s := sleep.Sleeper{}
+ defer s.Done()
+
+ const wakerID = 1
+ w := sleep.Waker{}
+ s.AddWaker(&w, wakerID)
+
+ e := testAddrs[0]
+
+ if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
+ }
+ id, ok := s.Fetch(true /* block */)
+ if !ok {
+ t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)")
+ }
+ if id != wakerID {
+ t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID)
+ }
+
+ if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil {
+ t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
+ } else if got != e.linkAddr {
+ t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
+ }
+ }
+
+ // Check that RemoveWaker works.
+ {
+ linkRes := &testLinkAddressResolver{cache: c}
+ s := sleep.Sleeper{}
+ defer s.Done()
+
+ const wakerID = 2 // different than the ID used in the sanity check
+ w := sleep.Waker{}
+ s.AddWaker(&w, wakerID)
+
+ e := testAddrs[1]
+ linkRes.onLinkAddressRequest = func() {
+ // Remove the waker before the linkAddrCache has the opportunity to send
+ // a notification.
+ c.removeWaker(e.addr, &w)
+ }
+
+ if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
+ }
+
+ if got, err := getBlocking(c, e.addr, linkRes); err != nil {
+ t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
+ } else if got != e.linkAddr {
+ t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
+ }
+
+ if id, ok := s.Fetch(false /* block */); ok {
+ t.Fatalf("unexpected notification from waker with id %d", id)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 6f86abc98..73a01c2dd 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -150,10 +150,10 @@ type ndpDNSSLEvent struct {
type ndpDHCPv6Event struct {
nicID tcpip.NICID
- configuration stack.DHCPv6ConfigurationFromNDPRA
+ configuration ipv6.DHCPv6ConfigurationFromNDPRA
}
-var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
+var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
// related events happen for test purposes.
@@ -170,7 +170,7 @@ type ndpDispatcher struct {
dhcpv6ConfigurationC chan ndpDHCPv6Event
}
-// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
+// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus.
func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
if n.dadC != nil {
n.dadC <- ndpDADEvent{
@@ -182,7 +182,7 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, add
}
}
-// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered.
+// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered.
func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
if c := n.routerC; c != nil {
c <- ndpRouterEvent{
@@ -195,7 +195,7 @@ func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.
return n.rememberRouter
}
-// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated.
+// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated.
func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
if c := n.routerC; c != nil {
c <- ndpRouterEvent{
@@ -206,7 +206,7 @@ func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip
}
}
-// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered.
+// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered.
func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool {
if c := n.prefixC; c != nil {
c <- ndpPrefixEvent{
@@ -219,7 +219,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip
return n.rememberPrefix
}
-// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated.
+// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated.
func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) {
if c := n.prefixC; c != nil {
c <- ndpPrefixEvent{
@@ -261,7 +261,7 @@ func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpi
}
}
-// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption.
+// Implements ipv6.NDPDispatcher.OnRecursiveDNSServerOption.
func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) {
if c := n.rdnssC; c != nil {
c <- ndpRDNSSEvent{
@@ -274,7 +274,7 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc
}
}
-// Implements stack.NDPDispatcher.OnDNSSearchListOption.
+// Implements ipv6.NDPDispatcher.OnDNSSearchListOption.
func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) {
if n.dnsslC != nil {
n.dnsslC <- ndpDNSSLEvent{
@@ -285,8 +285,8 @@ func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []s
}
}
-// Implements stack.NDPDispatcher.OnDHCPv6Configuration.
-func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) {
+// Implements ipv6.NDPDispatcher.OnDHCPv6Configuration.
+func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration ipv6.DHCPv6ConfigurationFromNDPRA) {
if c := n.dhcpv6ConfigurationC; c != nil {
c <- ndpDHCPv6Event{
nicID,
@@ -319,13 +319,12 @@ func TestDADDisabled(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
}
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- }
-
e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ })},
+ })
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -413,19 +412,21 @@ func TestDADResolve(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- }
- opts.NDPConfigs.RetransmitTimer = test.retransTimer
- opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
e := channelLinkWithHeaderLength{
Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1),
headerLength: test.linkHeaderLen,
}
e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- s := stack.New(opts)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ipv6.NDPConfigurations{
+ RetransmitTimer: test.retransTimer,
+ DupAddrDetectTransmits: test.dupAddrDetectTransmits,
+ },
+ })},
+ })
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -541,7 +542,7 @@ func TestDADResolve(t *testing.T) {
// As per RFC 4861 section 4.3, a possible option is the Source Link
// Layer option, but this option MUST NOT be included when the source
// address of the packet is the unspecified address.
- checker.IPv6(t, p.Pkt.Header.View(),
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(header.IPv6Any),
checker.DstAddr(snmc),
checker.TTL(header.NDPHopLimit),
@@ -550,14 +551,34 @@ func TestDADResolve(t *testing.T) {
checker.NDPNSOptions(nil),
))
- if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
}
}
})
}
}
+func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(tgt)
+ snmc := header.SolicitedNodeAddr(tgt)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
+ })
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
+}
+
// TestDADFail tests to make sure that the DAD process fails if another node is
// detected to be performing DAD on the same address (receive an NS message from
// a node doing DAD for the same address), or if another node is detected to own
@@ -567,39 +588,19 @@ func TestDADFail(t *testing.T) {
tests := []struct {
name string
- makeBuf func(tgt tcpip.Address) buffer.Prependable
+ rxPkt func(e *channel.Endpoint, tgt tcpip.Address)
getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
}{
{
- "RxSolicit",
- func(tgt tcpip.Address) buffer.Prependable {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
- ns.SetTargetAddress(tgt)
- snmc := header.SolicitedNodeAddr(tgt)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: header.IPv6Any,
- DstAddr: snmc,
- })
-
- return hdr
-
- },
- func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RxSolicit",
+ rxPkt: rxNDPSolicit,
+ getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return s.NeighborSolicit
},
},
{
- "RxAdvert",
- func(tgt tcpip.Address) buffer.Prependable {
+ name: "RxAdvert",
+ rxPkt: func(e *channel.Endpoint, tgt tcpip.Address) {
naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
@@ -621,11 +622,9 @@ func TestDADFail(t *testing.T) {
SrcAddr: tgt,
DstAddr: header.IPv6AllNodesMulticastAddress,
})
-
- return hdr
-
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
},
- func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return s.NeighborAdvert
},
},
@@ -636,16 +635,16 @@ func TestDADFail(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
}
- ndpConfigs := stack.DefaultNDPConfigurations()
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- }
- opts.NDPConfigs.RetransmitTimer = time.Second * 2
+ ndpConfigs := ipv6.DefaultNDPConfigurations()
+ ndpConfigs.RetransmitTimer = time.Second * 2
e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ })},
+ })
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -664,12 +663,8 @@ func TestDADFail(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
- // Receive a packet to simulate multiple nodes owning or
- // attempting to own the same address.
- hdr := test.makeBuf(addr1)
- e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
+ // Receive a packet to simulate an address conflict.
+ test.rxPkt(e, addr1)
stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
if got := stat.Value(); got != 1 {
@@ -753,18 +748,19 @@ func TestDADStop(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
}
- ndpConfigs := stack.NDPConfigurations{
+
+ ndpConfigs := ipv6.NDPConfigurations{
RetransmitTimer: time.Second,
DupAddrDetectTransmits: 2,
}
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- NDPConfigs: ndpConfigs,
- }
e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ })},
+ })
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
@@ -814,19 +810,6 @@ func TestDADStop(t *testing.T) {
}
}
-// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if
-// we attempt to update NDP configurations using an invalid NICID.
-func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- })
-
- // No NIC with ID 1 yet.
- if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID {
- t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID)
- }
-}
-
// TestSetNDPConfigurations tests that we can update and use per-interface NDP
// configurations without affecting the default NDP configurations or other
// interfaces' configurations.
@@ -862,8 +845,9 @@ func TestSetNDPConfigurations(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDisp,
+ })},
})
expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) {
@@ -891,12 +875,15 @@ func TestSetNDPConfigurations(t *testing.T) {
}
// Update the NDP configurations on NIC(1) to use DAD.
- configs := stack.NDPConfigurations{
+ configs := ipv6.NDPConfigurations{
DupAddrDetectTransmits: test.dupAddrDetectTransmits,
RetransmitTimer: test.retransmitTimer,
}
- if err := s.SetNDPConfigurations(nicID1, configs); err != nil {
- t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err)
+ if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err)
+ } else {
+ ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
+ ndpEP.SetNDPConfigurations(configs)
}
// Created after updating NIC(1)'s NDP configurations
@@ -1024,7 +1011,9 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
DstAddr: header.IPv6AllNodesMulticastAddress,
})
- return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()}
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
}
// raBufWithOpts returns a valid NDP Router Advertisement with options.
@@ -1110,14 +1099,15 @@ func TestNoRouterDiscovery(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: handle,
- DiscoverDefaultRouters: discover,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverDefaultRouters: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
- s.SetForwarding(forwarding)
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -1148,12 +1138,13 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1189,12 +1180,13 @@ func TestRouterDiscovery(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
expectRouterEvent := func(addr tcpip.Address, discovered bool) {
@@ -1254,7 +1246,7 @@ func TestRouterDiscovery(t *testing.T) {
default:
}
- // Wait for lladdr2's router invalidation timer to fire. The lifetime
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
@@ -1271,7 +1263,7 @@ func TestRouterDiscovery(t *testing.T) {
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
expectRouterEvent(llAddr2, false)
- // Wait for lladdr3's router invalidation timer to fire. The lifetime
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
@@ -1282,7 +1274,7 @@ func TestRouterDiscovery(t *testing.T) {
}
// TestRouterDiscoveryMaxRouters tests that only
-// stack.MaxDiscoveredDefaultRouters discovered routers are remembered.
+// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered.
func TestRouterDiscoveryMaxRouters(t *testing.T) {
ndpDisp := ndpDispatcher{
routerC: make(chan ndpRouterEvent, 1),
@@ -1290,12 +1282,13 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1303,14 +1296,14 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
}
// Receive an RA from 2 more than the max number of discovered routers.
- for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ {
+ for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ {
linkAddr := []byte{2, 2, 3, 4, 5, 0}
linkAddr[5] = byte(i)
llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
- if i <= stack.MaxDiscoveredDefaultRouters {
+ if i <= ipv6.MaxDiscoveredDefaultRouters {
select {
case e := <-ndpDisp.routerC:
if diff := checkRouterEvent(e, llAddr, true); diff != "" {
@@ -1355,14 +1348,15 @@ func TestNoPrefixDiscovery(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: handle,
- DiscoverOnLinkPrefixes: discover,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverOnLinkPrefixes: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
- s.SetForwarding(forwarding)
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -1396,13 +1390,14 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: false,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1442,12 +1437,13 @@ func TestPrefixDiscovery(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1502,7 +1498,7 @@ func TestPrefixDiscovery(t *testing.T) {
default:
}
- // Wait for prefix2's most recent invalidation timer plus some buffer to
+ // Wait for prefix2's most recent invalidation job plus some buffer to
// expire.
select {
case e := <-ndpDisp.prefixC:
@@ -1542,12 +1538,13 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1618,33 +1615,34 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}
// TestPrefixDiscoveryMaxRouters tests that only
-// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
+// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3),
+ prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3),
rememberPrefix: true,
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: false,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
}
- optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2)
- prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
+ optSer := make(header.NDPOptionsSerializer, ipv6.MaxDiscoveredOnLinkPrefixes+2)
+ prefixes := [ipv6.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
// Receive an RA with 2 more than the max number of discovered on-link
// prefixes.
- for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
+ for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ {
prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0}
prefixAddr[7] = byte(i)
prefix := tcpip.AddressWithPrefix{
@@ -1662,8 +1660,8 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
}
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
- for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
- if i < stack.MaxDiscoveredOnLinkPrefixes {
+ for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ {
+ if i < ipv6.MaxDiscoveredOnLinkPrefixes {
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" {
@@ -1689,13 +1687,7 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix)
AddressWithPrefix: item,
}
- for _, i := range list {
- if i == protocolAddress {
- return true
- }
- }
-
- return false
+ return containsAddr(list, protocolAddress)
}
// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
@@ -1719,14 +1711,15 @@ func TestNoAutoGenAddr(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: handle,
- AutoGenGlobalAddresses: autogen,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handle,
+ AutoGenGlobalAddresses: autogen,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
- s.SetForwarding(forwarding)
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -1752,14 +1745,14 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix,
// TestAutoGenAddr tests that an address is properly generated and invalidated
// when configured to do so.
-func TestAutoGenAddr(t *testing.T) {
+func TestAutoGenAddr2(t *testing.T) {
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -1769,12 +1762,13 @@ func TestAutoGenAddr(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -1879,14 +1873,14 @@ func TestAutoGenTempAddr(t *testing.T) {
newMinVLDuration = newMinVL * time.Second
)
- savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate
- savedMaxDesync := stack.MaxDesyncFactor
+ savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate
+ savedMaxDesync := ipv6.MaxDesyncFactor
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate
- stack.MaxDesyncFactor = savedMaxDesync
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate
+ ipv6.MaxDesyncFactor = savedMaxDesync
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- stack.MaxDesyncFactor = time.Nanosecond
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+ ipv6.MaxDesyncFactor = time.Nanosecond
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -1934,16 +1928,17 @@ func TestAutoGenTempAddr(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: test.dupAddrTransmits,
- RetransmitTimer: test.retransmitTimer,
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- TempIIDSeed: seed,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: test.dupAddrTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ TempIIDSeed: seed,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2122,11 +2117,11 @@ func TestAutoGenTempAddr(t *testing.T) {
func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
const nicID = 1
- savedMaxDesyncFactor := stack.MaxDesyncFactor
+ savedMaxDesyncFactor := ipv6.MaxDesyncFactor
defer func() {
- stack.MaxDesyncFactor = savedMaxDesyncFactor
+ ipv6.MaxDesyncFactor = savedMaxDesyncFactor
}()
- stack.MaxDesyncFactor = time.Nanosecond
+ ipv6.MaxDesyncFactor = time.Nanosecond
tests := []struct {
name string
@@ -2163,12 +2158,13 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- AutoGenIPv6LinkLocal: true,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ AutoGenIPv6LinkLocal: true,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2214,11 +2210,11 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
retransmitTimer = 2 * time.Second
)
- savedMaxDesyncFactor := stack.MaxDesyncFactor
+ savedMaxDesyncFactor := ipv6.MaxDesyncFactor
defer func() {
- stack.MaxDesyncFactor = savedMaxDesyncFactor
+ ipv6.MaxDesyncFactor = savedMaxDesyncFactor
}()
- stack.MaxDesyncFactor = 0
+ ipv6.MaxDesyncFactor = 0
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
@@ -2231,15 +2227,16 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2297,17 +2294,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
newMinVLDuration = newMinVL * time.Second
)
- savedMaxDesyncFactor := stack.MaxDesyncFactor
- savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime
- savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime
+ savedMaxDesyncFactor := ipv6.MaxDesyncFactor
+ savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime
+ savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime
defer func() {
- stack.MaxDesyncFactor = savedMaxDesyncFactor
- stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
- stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
+ ipv6.MaxDesyncFactor = savedMaxDesyncFactor
+ ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
+ ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
}()
- stack.MaxDesyncFactor = 0
- stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration
- stack.MinMaxTempAddrValidLifetime = newMinVLDuration
+ ipv6.MaxDesyncFactor = 0
+ ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration
+ ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
@@ -2320,16 +2317,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := stack.NDPConfigurations{
+ ndpConfigs := ipv6.NDPConfigurations{
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
RegenAdvanceDuration: newMinVLDuration - regenAfter,
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2385,8 +2383,11 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
// Stop generating temporary addresses
ndpConfigs.AutoGenTempGlobalAddresses = false
- if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
- t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else {
+ ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
+ ndpEP.SetNDPConfigurations(ndpConfigs)
}
// Wait for all the temporary addresses to get invalidated.
@@ -2395,7 +2396,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
for _, addr := range tempAddrs {
// Wait for a deprecation then invalidation event, or just an invalidation
// event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation timers could fire in any
+ // cases because the deprecation and invalidation jobs could execute in any
// order.
select {
case e := <-ndpDisp.autoGenAddrC:
@@ -2432,9 +2433,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
}
-// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's
-// regeneration timer gets updated when refreshing the address's lifetimes.
-func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
+// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's
+// regeneration job gets updated when refreshing the address's lifetimes.
+func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
const (
nicID = 1
regenAfter = 2 * time.Second
@@ -2442,17 +2443,17 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
newMinVLDuration = newMinVL * time.Second
)
- savedMaxDesyncFactor := stack.MaxDesyncFactor
- savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime
- savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime
+ savedMaxDesyncFactor := ipv6.MaxDesyncFactor
+ savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime
+ savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime
defer func() {
- stack.MaxDesyncFactor = savedMaxDesyncFactor
- stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
- stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
+ ipv6.MaxDesyncFactor = savedMaxDesyncFactor
+ ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
+ ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
}()
- stack.MaxDesyncFactor = 0
- stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration
- stack.MinMaxTempAddrValidLifetime = newMinVLDuration
+ ipv6.MaxDesyncFactor = 0
+ ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration
+ ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
@@ -2465,16 +2466,17 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := stack.NDPConfigurations{
+ ndpConfigs := ipv6.NDPConfigurations{
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
RegenAdvanceDuration: newMinVLDuration - regenAfter,
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2533,7 +2535,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
//
// A new temporary address should immediately be generated since the
// regeneration time has already passed since the last address was generated
- // - this regeneration does not depend on a timer.
+ // - this regeneration does not depend on a job.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(tempAddr2, newAddr)
@@ -2548,9 +2550,12 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
// as paased.
ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second
ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second
- if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
- t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
}
+ ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
+ ndpEP.SetNDPConfigurations(ndpConfigs)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
select {
case e := <-ndpDisp.autoGenAddrC:
@@ -2559,18 +2564,16 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
}
// Set the maximum lifetimes for temporary addresses such that on the next
- // RA, the regeneration timer gets reset.
+ // RA, the regeneration job gets scheduled again.
//
// The maximum lifetime is the sum of the minimum lifetimes for temporary
// addresses + the time that has already passed since the last address was
- // generated so that the regeneration timer is needed to generate the next
+ // generated so that the regeneration job is needed to generate the next
// address.
newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
- if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
- t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
- }
+ ndpEP.SetNDPConfigurations(ndpConfigs)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
}
@@ -2658,20 +2661,21 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := stack.NDPConfigurations{
+ ndpConfigs := ipv6.NDPConfigurations{
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: test.tempAddrs,
AutoGenAddressConflictRetries: 1,
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: test.nicNameFromID,
- },
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: test.nicNameFromID,
+ },
+ })},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
s.SetRouteTable([]tcpip.Route{{
@@ -2742,8 +2746,11 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
ndpDisp.dadC = make(chan ndpDADEvent, 2)
ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits
ndpConfigs.RetransmitTimer = retransmitTimer
- if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
- t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else {
+ ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
+ ndpEP.SetNDPConfigurations(ndpConfigs)
}
// Do SLAAC for prefix.
@@ -2757,9 +2764,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
// DAD failure to restart the local generation process.
addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1]
expectAutoGenAddrAsyncEvent(addr, newAddr)
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
+ rxNDPSolicit(e, addr.Address)
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
@@ -2790,20 +2795,22 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
// stack.Stack will have a default route through the router (llAddr3) installed
// and a static link-address (linkAddr3) added to the link address cache for the
// router.
-func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
+func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useNeighborCache bool) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
t.Helper()
ndpDisp := &ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: ndpDisp,
+ })},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ UseNeighborCache: useNeighborCache,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -2813,7 +2820,11 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
Gateway: llAddr3,
NIC: nicID,
}})
- s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ if useNeighborCache {
+ s.AddStaticNeighbor(nicID, llAddr3, linkAddr3)
+ } else {
+ s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ }
return ndpDisp, e, s
}
@@ -2887,329 +2898,366 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA
// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when
// receiving a PI with 0 preferred lifetime.
func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
- const nicID = 1
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ const nicID = 1
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
}
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
- }
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
- // Receive PI for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- expectPrimaryAddr(addr1)
+ // Receive PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ expectPrimaryAddr(addr1)
- // Deprecate addr for prefix1 immedaitely.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- // addr should still be the primary endpoint as there are no other addresses.
- expectPrimaryAddr(addr1)
+ // Deprecate addr for prefix1 immedaitely.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ // addr should still be the primary endpoint as there are no other addresses.
+ expectPrimaryAddr(addr1)
- // Refresh lifetimes of addr generated from prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
+ // Refresh lifetimes of addr generated from prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
- // Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
- // Deprecate addr for prefix2 immedaitely.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- expectAutoGenAddrEvent(addr2, deprecatedAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr1 should be the primary endpoint now since addr2 is deprecated but
- // addr1 is not.
- expectPrimaryAddr(addr1)
- // addr2 is deprecated but if explicitly requested, it should be used.
- fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
- if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
- }
+ // Deprecate addr for prefix2 immedaitely.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr1 should be the primary endpoint now since addr2 is deprecated but
+ // addr1 is not.
+ expectPrimaryAddr(addr1)
+ // addr2 is deprecated but if explicitly requested, it should be used.
+ fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+ }
- // Another PI w/ 0 preferred lifetime should not result in a deprecation
- // event.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
- }
+ // Another PI w/ 0 preferred lifetime should not result in a deprecation
+ // event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+ }
- // Refresh lifetimes of addr generated from prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
+ // Refresh lifetimes of addr generated from prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+ })
}
- expectPrimaryAddr(addr2)
}
-// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
+// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
// when its preferred lifetime expires.
-func TestAutoGenAddrTimerDeprecation(t *testing.T) {
+func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
}
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
- expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
- t.Helper()
+ expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
}
- case <-time.After(timeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- }
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
- }
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
- // Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
- // Receive a PI for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr1)
+ // Receive a PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr1)
- // Refresh lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
+ // Refresh lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
- // Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr2 should be the primary endpoint now since addr1 is deprecated but
- // addr2 is not.
- expectPrimaryAddr(addr2)
- // addr1 is deprecated but if explicitly requested, it should be used.
- fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since addr1 is deprecated but
+ // addr2 is not.
+ expectPrimaryAddr(addr2)
+ // addr1 is deprecated but if explicitly requested, it should be used.
+ fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
- // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
- // sure we do not get a deprecation event again.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr2)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
+ // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
+ // sure we do not get a deprecation event again.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
- // Refresh lifetimes for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- // addr1 is the primary endpoint again since it is non-deprecated now.
- expectPrimaryAddr(addr1)
+ // Refresh lifetimes for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ // addr1 is the primary endpoint again since it is non-deprecated now.
+ expectPrimaryAddr(addr1)
- // Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr2 should be the primary endpoint now since it is not deprecated.
- expectPrimaryAddr(addr2)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since it is not deprecated.
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
- // Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
+ // Wait for addr of prefix1 to be invalidated.
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
- // Refresh both lifetimes for addr of prefix2 to the same value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
+ // Refresh both lifetimes for addr of prefix2 to the same value.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
- // Wait for a deprecation then invalidation events, or just an invalidation
- // event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation handlers could be handled in
- // either deprecation then invalidation, or invalidation then deprecation
- // (which should be cancelled by the invalidation handler).
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
- // If we get a deprecation event first, we should get an invalidation
- // event almost immediately after.
+ // Wait for a deprecation then invalidation events, or just an invalidation
+ // event. We need to cover both cases but cannot deterministically hit both
+ // cases because the deprecation and invalidation handlers could be handled in
+ // either deprecation then invalidation, or invalidation then deprecation
+ // (which should be cancelled by the invalidation handler).
select {
case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
+ // If we get a deprecation event first, we should get an invalidation
+ // event almost immediately after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
+ // If we get an invalidation event first, we should not get a deprecation
+ // event after.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ } else {
+ t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
- // If we get an invalidation event first, we should not get a deprecation
- // event after.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- case <-time.After(defaultAsyncNegativeEventTimeout):
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should not have %s in the list of addresses", addr2)
+ }
+ // Should not have any primary endpoints.
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
}
- } else {
- t.Fatalf("got unexpected auto-generated event")
- }
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should not have %s in the list of addresses", addr2)
- }
- // Should not have any primary endpoints.
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
- }
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
- }
- defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
- if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
- t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+ if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
+ t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+ }
+ })
}
}
@@ -3219,12 +3267,12 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
const infiniteVLSeconds = 2
const minVLSeconds = 1
savedIL := header.NDPInfiniteLifetime
- savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate
+ savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
header.NDPInfiniteLifetime = savedIL
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
@@ -3268,12 +3316,13 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3318,11 +3367,11 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
const infiniteVL = 4294967295
const newMinVL = 4
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
}()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
@@ -3410,12 +3459,13 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
}
e := channel.New(10, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3476,12 +3526,13 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3513,8 +3564,8 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
expectAutoGenAddrEvent(addr, invalidatedAddr)
- // Wait for the original valid lifetime to make sure the original timer
- // got stopped/cleaned up.
+ // Wait for the original valid lifetime to make sure the original job got
+ // cancelled/cleaned up.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
@@ -3527,110 +3578,128 @@ func TestAutoGenAddrRemoval(t *testing.T) {
func TestAutoGenAddrAfterRemoval(t *testing.T) {
const nicID = 1
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
}
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
-
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
- }
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache)
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
- // Receive a PI to auto-generate addr1 with a large valid and preferred
- // lifetime.
- const largeLifetimeSeconds = 999
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- expectAutoGenAddrEvent(addr1, newAddr)
- expectPrimaryAddr(addr1)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
- // Add addr2 as a static address.
- protoAddr2 := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr2,
- }
- if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
- }
- // addr2 should be more preferred now since it is at the front of the primary
- // list.
- expectPrimaryAddr(addr2)
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
- // Get a route using addr2 to increment its reference count then remove it
- // to leave it in the permanentExpired state.
- r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
- }
- defer r.Release()
- if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
- t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
- }
- // addr1 should be preferred again since addr2 is in the expired state.
- expectPrimaryAddr(addr1)
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
- // Receive a PI to auto-generate addr2 as valid and preferred.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- expectAutoGenAddrEvent(addr2, newAddr)
- // addr2 should be more preferred now that it is closer to the front of the
- // primary list and not deprecated.
- expectPrimaryAddr(addr2)
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
- // Removing the address should result in an invalidation event immediately.
- // It should still be in the permanentExpired state because r is still held.
- //
- // We remove addr2 here to make sure addr2 was marked as a SLAAC address
- // (it was previously marked as a static address).
- if err := s.RemoveAddress(1, addr2.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
- }
- expectAutoGenAddrEvent(addr2, invalidatedAddr)
- // addr1 should be more preferred since addr2 is in the expired state.
- expectPrimaryAddr(addr1)
+ // Receive a PI to auto-generate addr1 with a large valid and preferred
+ // lifetime.
+ const largeLifetimeSeconds = 999
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectPrimaryAddr(addr1)
- // Receive a PI to auto-generate addr2 as valid and deprecated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- // addr1 should still be more preferred since addr2 is deprecated, even though
- // it is closer to the front of the primary list.
- expectPrimaryAddr(addr1)
+ // Add addr2 as a static address.
+ protoAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr2,
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ }
+ // addr2 should be more preferred now since it is at the front of the primary
+ // list.
+ expectPrimaryAddr(addr2)
- // Receive a PI to refresh addr2's preferred lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto gen addr event")
- default:
- }
- // addr2 should be more preferred now that it is not deprecated.
- expectPrimaryAddr(addr2)
+ // Get a route using addr2 to increment its reference count then remove it
+ // to leave it in the permanentExpired state.
+ r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
+ }
+ // addr1 should be preferred again since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and preferred.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr2 should be more preferred now that it is closer to the front of the
+ // primary list and not deprecated.
+ expectPrimaryAddr(addr2)
+
+ // Removing the address should result in an invalidation event immediately.
+ // It should still be in the permanentExpired state because r is still held.
+ //
+ // We remove addr2 here to make sure addr2 was marked as a SLAAC address
+ // (it was previously marked as a static address).
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ // addr1 should be more preferred since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr1 should still be more preferred since addr2 is deprecated, even though
+ // it is closer to the front of the primary list.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to refresh addr2's preferred lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto gen addr event")
+ default:
+ }
+ // addr2 should be more preferred now that it is not deprecated.
+ expectPrimaryAddr(addr2)
- if err := s.RemoveAddress(1, addr2.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ expectPrimaryAddr(addr1)
+ })
}
- expectAutoGenAddrEvent(addr2, invalidatedAddr)
- expectPrimaryAddr(addr1)
}
// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
@@ -3643,12 +3712,13 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
@@ -3724,18 +3794,19 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
},
- SecretKey: secretKey,
- },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })},
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -3799,11 +3870,11 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
const lifetimeSeconds = 10
// Needed for the temporary address sub test.
- savedMaxDesync := stack.MaxDesyncFactor
+ savedMaxDesync := ipv6.MaxDesyncFactor
defer func() {
- stack.MaxDesyncFactor = savedMaxDesync
+ ipv6.MaxDesyncFactor = savedMaxDesync
}()
- stack.MaxDesyncFactor = time.Nanosecond
+ ipv6.MaxDesyncFactor = time.Nanosecond
var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
secretKey := secretKeyBuf[:]
@@ -3881,14 +3952,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
addrTypes := []struct {
name string
- ndpConfigs stack.NDPConfigurations
+ ndpConfigs ipv6.NDPConfigurations
autoGenLinkLocal bool
prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix
}{
{
name: "Global address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
HandleRAs: true,
@@ -3906,7 +3977,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
},
{
name: "LinkLocal address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
},
@@ -3920,7 +3991,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
},
{
name: "Temporary address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
HandleRAs: true,
@@ -3972,16 +4043,17 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
ndpConfigs := addrType.ndpConfigs
ndpConfigs.AutoGenAddressConflictRetries = maxRetries
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
},
- SecretKey: secretKey,
- },
+ })},
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -4002,9 +4074,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
// Simulate a DAD conflict.
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
+ rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
expectDADEvent(t, &ndpDisp, addr.Address, false)
@@ -4062,14 +4132,14 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
addrTypes := []struct {
name string
- ndpConfigs stack.NDPConfigurations
+ ndpConfigs ipv6.NDPConfigurations
autoGenLinkLocal bool
subnet tcpip.Subnet
triggerSLAACFn func(e *channel.Endpoint)
}{
{
name: "Global address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
HandleRAs: true,
@@ -4085,7 +4155,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
},
{
name: "LinkLocal address",
- ndpConfigs: stack.NDPConfigurations{
+ ndpConfigs: ipv6.NDPConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
AutoGenAddressConflictRetries: maxRetries,
@@ -4108,10 +4178,11 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: addrType.ndpConfigs,
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -4141,9 +4212,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
expectAutoGenAddrEvent(addr, newAddr)
// Simulate a DAD conflict.
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
+ rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(addr, invalidatedAddr)
select {
case e := <-ndpDisp.dadC:
@@ -4193,21 +4262,22 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenAddressConflictRetries: maxRetries,
- },
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
},
- SecretKey: secretKey,
- },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })},
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -4239,9 +4309,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
// Simulate a DAD conflict after some time has passed.
time.Sleep(failureTimer)
- if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
- t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
- }
+ rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(addr, invalidatedAddr)
select {
case e := <-ndpDisp.dadC:
@@ -4402,11 +4470,12 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -4452,11 +4521,12 @@ func TestNDPDNSSearchListDispatch(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -4583,7 +4653,7 @@ func TestCleanupNDPState(t *testing.T) {
name: "Enable forwarding",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
- s.SetForwarding(true)
+ s.SetForwarding(ipv6.ProtocolNumber, true)
},
keepAutoGenLinkLocal: true,
maxAutoGenAddrEvents: 4,
@@ -4637,15 +4707,16 @@ func TestCleanupNDPState(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: true,
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- DiscoverOnLinkPrefixes: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: true,
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
expectRouterEvent := func() (bool, ndpRouterEvent) {
@@ -4910,18 +4981,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ expectDHCPv6Event := func(configuration ipv6.DHCPv6ConfigurationFromNDPRA) {
t.Helper()
select {
case e := <-ndpDisp.dhcpv6ConfigurationC:
@@ -4945,7 +5017,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// Even if the first RA reports no DHCPv6 configurations are available, the
// dispatcher should get an event.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ expectDHCPv6Event(ipv6.DHCPv6NoConfiguration)
// Receiving the same update again should not result in an event to the
// dispatcher.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
@@ -4954,19 +5026,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// Receive an RA that updates the DHCPv6 configuration to Other
// Configurations.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
expectNoDHCPv6Event()
// Receive an RA that updates the DHCPv6 configuration to Managed Address.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
- expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ expectDHCPv6Event(ipv6.DHCPv6ManagedAddress)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
expectNoDHCPv6Event()
// Receive an RA that updates the DHCPv6 configuration to none.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ expectDHCPv6Event(ipv6.DHCPv6NoConfiguration)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
expectNoDHCPv6Event()
@@ -4974,7 +5046,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
//
// Note, when the M flag is set, the O flag is redundant.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
- expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ expectDHCPv6Event(ipv6.DHCPv6ManagedAddress)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
expectNoDHCPv6Event()
// Even though the DHCPv6 flags are different, the effective configuration is
@@ -4987,7 +5059,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// Receive an RA that updates the DHCPv6 configuration to Other
// Configurations.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
expectNoDHCPv6Event()
@@ -5002,7 +5074,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
// Receive an RA that updates the DHCPv6 configuration to Other
// Configurations.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations)
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
expectNoDHCPv6Event()
}
@@ -5140,16 +5212,15 @@ func TestRouterSolicitation(t *testing.T) {
t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
}
- checker.IPv6(t,
- p.Pkt.Header.View(),
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(test.expectedSrcAddr),
checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
checker.TTL(header.NDPHopLimit),
checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
)
- if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
}
}
waitForNothing := func(timeout time.Duration) {
@@ -5161,12 +5232,13 @@ func TestRouterSolicitation(t *testing.T) {
}
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
})
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -5230,11 +5302,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
name: "Enable and disable forwarding",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
- s.SetForwarding(false)
+ s.SetForwarding(ipv6.ProtocolNumber, false)
},
stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
- s.SetForwarding(true)
+ s.SetForwarding(ipv6.ProtocolNumber, true)
},
},
@@ -5294,19 +5366,20 @@ func TestStopStartSolicitingRouters(t *testing.T) {
if p.Proto != header.IPv6ProtocolNumber {
t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
}
- checker.IPv6(t, p.Pkt.Header.View(),
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(header.IPv6Any),
checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
checker.TTL(header.NDPHopLimit),
checker.NDPRS())
}
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: interval,
- MaxRtrSolicitationDelay: delay,
- },
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
new file mode 100644
index 000000000..27e1feec0
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -0,0 +1,333 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const neighborCacheSize = 512 // max entries per interface
+
+// neighborCache maps IP addresses to link addresses. It uses the Least
+// Recently Used (LRU) eviction strategy to implement a bounded cache for
+// dynmically acquired entries. It contains the state machine and configuration
+// for running Neighbor Unreachability Detection (NUD).
+//
+// There are two types of entries in the neighbor cache:
+// 1. Dynamic entries are discovered automatically by neighbor discovery
+// protocols (e.g. ARP, NDP). These protocols will attempt to reconfirm
+// reachability with the device once the entry's state becomes Stale.
+// 2. Static entries are explicitly added by a user and have no expiration.
+// Their state is always Static. The amount of static entries stored in the
+// cache is unbounded.
+//
+// neighborCache implements NUDHandler.
+type neighborCache struct {
+ nic *NIC
+ state *NUDState
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ cache map[tcpip.Address]*neighborEntry
+ dynamic struct {
+ lru neighborEntryList
+
+ // count tracks the amount of dynamic entries in the cache. This is
+ // needed since static entries do not count towards the LRU cache
+ // eviction strategy.
+ count uint16
+ }
+}
+
+var _ NUDHandler = (*neighborCache)(nil)
+
+// getOrCreateEntry retrieves a cache entry associated with addr. The
+// returned entry is always refreshed in the cache (it is reachable via the
+// map, and its place is bumped in LRU).
+//
+// If a matching entry exists in the cache, it is returned. If no matching
+// entry exists and the cache is full, an existing entry is evicted via LRU,
+// reset to state incomplete, and returned. If no matching entry exists and the
+// cache is not full, a new entry with state incomplete is allocated and
+// returned.
+func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if entry, ok := n.cache[remoteAddr]; ok {
+ entry.mu.RLock()
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.lru.PushFront(entry)
+ }
+ entry.mu.RUnlock()
+ return entry
+ }
+
+ // The entry that needs to be created must be dynamic since all static
+ // entries are directly added to the cache via addStaticEntry.
+ entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes)
+ if n.dynamic.count == neighborCacheSize {
+ e := n.dynamic.lru.Back()
+ e.mu.Lock()
+
+ delete(n.cache, e.neigh.Addr)
+ n.dynamic.lru.Remove(e)
+ n.dynamic.count--
+
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Unknown)
+ e.notifyWakersLocked()
+ e.mu.Unlock()
+ }
+ n.cache[remoteAddr] = entry
+ n.dynamic.lru.PushFront(entry)
+ n.dynamic.count++
+ return entry
+}
+
+// entry looks up the neighbor cache for translating address to link address
+// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there
+// is a LinkAddressResolver registered with the network protocol, the cache
+// attempts to resolve the address and returns ErrWouldBlock. If a Waker is
+// provided, it will be notified when address resolution is complete (success
+// or not).
+//
+// If address resolution is required, ErrNoLinkAddress and a notification
+// channel is returned for the top level caller to block. Channel is closed
+// once address resolution is complete (success or not).
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+ if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
+ e := NeighborEntry{
+ Addr: remoteAddr,
+ LocalAddr: localAddr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAt: time.Now(),
+ }
+ return e, nil, nil
+ }
+
+ entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
+ switch s := entry.neigh.State; s {
+ case Reachable, Static:
+ return entry.neigh, nil, nil
+
+ case Unknown, Incomplete, Stale, Delay, Probe:
+ entry.addWakerLocked(w)
+
+ if entry.done == nil {
+ // Address resolution needs to be initiated.
+ if linkRes == nil {
+ return entry.neigh, nil, tcpip.ErrNoLinkAddress
+ }
+ entry.done = make(chan struct{})
+ }
+
+ entry.handlePacketQueuedLocked()
+ return entry.neigh, entry.done, tcpip.ErrWouldBlock
+
+ case Failed:
+ return entry.neigh, nil, tcpip.ErrNoLinkAddress
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", s))
+ }
+}
+
+// removeWaker removes a waker that has been added when link resolution for
+// addr was requested.
+func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
+ n.mu.Lock()
+ if entry, ok := n.cache[addr]; ok {
+ delete(entry.wakers, waker)
+ }
+ n.mu.Unlock()
+}
+
+// entries returns all entries in the neighbor cache.
+func (n *neighborCache) entries() []NeighborEntry {
+ entries := make([]NeighborEntry, 0, len(n.cache))
+ n.mu.RLock()
+ for _, entry := range n.cache {
+ entry.mu.RLock()
+ entries = append(entries, entry.neigh)
+ entry.mu.RUnlock()
+ }
+ n.mu.RUnlock()
+ return entries
+}
+
+// addStaticEntry adds a static entry to the neighbor cache, mapping an IP
+// address to a link address. If a dynamic entry exists in the neighbor cache
+// with the same address, it will be replaced with this static entry. If a
+// static entry exists with the same address but different link address, it
+// will be updated with the new link address. If a static entry exists with the
+// same address and link address, nothing will happen.
+func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if entry, ok := n.cache[addr]; ok {
+ entry.mu.Lock()
+ if entry.neigh.State != Static {
+ // Dynamic entry found with the same address.
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ } else if entry.neigh.LinkAddr == linkAddr {
+ // Static entry found with the same address and link address.
+ entry.mu.Unlock()
+ return
+ } else {
+ // Static entry found with the same address but different link address.
+ entry.neigh.LinkAddr = linkAddr
+ entry.dispatchChangeEventLocked(entry.neigh.State)
+ entry.mu.Unlock()
+ return
+ }
+
+ // Notify that resolution has been interrupted, just in case the entry was
+ // in the Incomplete or Probe state.
+ entry.dispatchRemoveEventLocked()
+ entry.setStateLocked(Unknown)
+ entry.notifyWakersLocked()
+ entry.mu.Unlock()
+ }
+
+ entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
+ n.cache[addr] = entry
+}
+
+// removeEntryLocked removes the specified entry from the neighbor cache.
+func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ }
+ if entry.neigh.State != Failed {
+ entry.dispatchRemoveEventLocked()
+ }
+ entry.setStateLocked(Unknown)
+ entry.notifyWakersLocked()
+
+ delete(n.cache, entry.neigh.Addr)
+}
+
+// removeEntry removes a dynamic or static entry by address from the neighbor
+// cache. Returns true if the entry was found and deleted.
+func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ entry, ok := n.cache[addr]
+ if !ok {
+ return false
+ }
+
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
+ n.removeEntryLocked(entry)
+ return true
+}
+
+// clear removes all dynamic and static entries from the neighbor cache.
+func (n *neighborCache) clear() {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ for _, entry := range n.cache {
+ entry.mu.Lock()
+ entry.dispatchRemoveEventLocked()
+ entry.setStateLocked(Unknown)
+ entry.notifyWakersLocked()
+ entry.mu.Unlock()
+ }
+
+ n.dynamic.lru = neighborEntryList{}
+ n.cache = make(map[tcpip.Address]*neighborEntry)
+ n.dynamic.count = 0
+}
+
+// config returns the NUD configuration.
+func (n *neighborCache) config() NUDConfigurations {
+ return n.state.Config()
+}
+
+// setConfig changes the NUD configuration.
+//
+// If config contains invalid NUD configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (n *neighborCache) setConfig(config NUDConfigurations) {
+ config.resetInvalidFields()
+ n.state.SetConfig(config)
+}
+
+// HandleProbe implements NUDHandler.HandleProbe by following the logic defined
+// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled
+// by the caller.
+func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
+ entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+ entry.mu.Lock()
+ entry.handleProbeLocked(remoteLinkAddr)
+ entry.mu.Unlock()
+}
+
+// HandleConfirmation implements NUDHandler.HandleConfirmation by following the
+// logic defined in RFC 4861 section 7.2.5.
+//
+// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other
+// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol
+// should be deployed where preventing access to the broadcast segment might
+// not be possible. SEND uses RSA key pairs to produce cryptographically
+// generated addresses, as defined in RFC 3972, Cryptographically Generated
+// Addresses (CGA). This ensures that the claimed source of an NDP message is
+// the owner of the claimed address.
+func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
+ n.mu.RLock()
+ entry, ok := n.cache[addr]
+ n.mu.RUnlock()
+ if ok {
+ entry.mu.Lock()
+ entry.handleConfirmationLocked(linkAddr, flags)
+ entry.mu.Unlock()
+ }
+ // The confirmation SHOULD be silently discarded if the recipient did not
+ // initiate any communication with the target. This is indicated if there is
+ // no matching entry for the remote address.
+}
+
+// HandleUpperLevelConfirmation implements
+// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in
+// RFC 4861 section 7.3.1.
+func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) {
+ n.mu.RLock()
+ entry, ok := n.cache[addr]
+ n.mu.RUnlock()
+ if ok {
+ entry.mu.Lock()
+ entry.handleUpperLevelConfirmationLocked()
+ entry.mu.Unlock()
+ }
+}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
new file mode 100644
index 000000000..a0b7da5cd
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -0,0 +1,1727 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math"
+ "math/rand"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+)
+
+const (
+ // entryStoreSize is the default number of entries that will be generated and
+ // added to the entry store. This number needs to be larger than the size of
+ // the neighbor cache to give ample opportunity for verifying behavior during
+ // cache overflows. Four times the size of the neighbor cache allows for
+ // three complete cache overflows.
+ entryStoreSize = 4 * neighborCacheSize
+
+ // typicalLatency is the typical latency for an ARP or NDP packet to travel
+ // to a router and back.
+ typicalLatency = time.Millisecond
+
+ // testEntryBroadcastAddr is a special address that indicates a packet should
+ // be sent to all nodes.
+ testEntryBroadcastAddr = tcpip.Address("broadcast")
+
+ // testEntryLocalAddr is the source address of neighbor probes.
+ testEntryLocalAddr = tcpip.Address("local_addr")
+
+ // testEntryBroadcastLinkAddr is a special link address sent back to
+ // multicast neighbor probes.
+ testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast")
+
+ // infiniteDuration indicates that a task will not occur in our lifetime.
+ infiniteDuration = time.Duration(math.MaxInt64)
+)
+
+// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor
+// entries. The UpdatedAt field is ignored due to a lack of a deterministic
+// method to predict the time that an event will be dispatched.
+func entryDiffOpts() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
+ }
+}
+
+// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to
+// sort slices of entries for cases where ordering must be ignored.
+func entryDiffOptsWithSort() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
+ cmpopts.SortSlices(func(a, b NeighborEntry) bool {
+ return strings.Compare(string(a.Addr), string(b.Addr)) < 0
+ }),
+ }
+}
+
+func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
+ config.resetInvalidFields()
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
+ return &neighborCache{
+ nic: &NIC{
+ stack: &Stack{
+ clock: clock,
+ nudDisp: nudDisp,
+ },
+ id: 1,
+ },
+ state: NewNUDState(config, rng),
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+}
+
+// testEntryStore contains a set of IP to NeighborEntry mappings.
+type testEntryStore struct {
+ mu sync.RWMutex
+ entriesMap map[tcpip.Address]NeighborEntry
+}
+
+func toAddress(i int) tcpip.Address {
+ buf := new(bytes.Buffer)
+ binary.Write(buf, binary.BigEndian, uint8(1))
+ binary.Write(buf, binary.BigEndian, uint8(0))
+ binary.Write(buf, binary.BigEndian, uint16(i))
+ return tcpip.Address(buf.String())
+}
+
+func toLinkAddress(i int) tcpip.LinkAddress {
+ buf := new(bytes.Buffer)
+ binary.Write(buf, binary.BigEndian, uint8(1))
+ binary.Write(buf, binary.BigEndian, uint8(0))
+ binary.Write(buf, binary.BigEndian, uint32(i))
+ return tcpip.LinkAddress(buf.String())
+}
+
+// newTestEntryStore returns a testEntryStore pre-populated with entries.
+func newTestEntryStore() *testEntryStore {
+ store := &testEntryStore{
+ entriesMap: make(map[tcpip.Address]NeighborEntry),
+ }
+ for i := 0; i < entryStoreSize; i++ {
+ addr := toAddress(i)
+ linkAddr := toLinkAddress(i)
+
+ store.entriesMap[addr] = NeighborEntry{
+ Addr: addr,
+ LocalAddr: testEntryLocalAddr,
+ LinkAddr: linkAddr,
+ }
+ }
+ return store
+}
+
+// size returns the number of entries in the store.
+func (s *testEntryStore) size() int {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return len(s.entriesMap)
+}
+
+// entry returns the entry at index i. Returns an empty entry and false if i is
+// out of bounds.
+func (s *testEntryStore) entry(i int) (NeighborEntry, bool) {
+ return s.entryByAddr(toAddress(i))
+}
+
+// entryByAddr returns the entry matching addr for situations when the index is
+// not available. Returns an empty entry and false if no entries match addr.
+func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ entry, ok := s.entriesMap[addr]
+ return entry, ok
+}
+
+// entries returns all entries in the store.
+func (s *testEntryStore) entries() []NeighborEntry {
+ entries := make([]NeighborEntry, 0, len(s.entriesMap))
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for i := 0; i < entryStoreSize; i++ {
+ addr := toAddress(i)
+ if entry, ok := s.entriesMap[addr]; ok {
+ entries = append(entries, entry)
+ }
+ }
+ return entries
+}
+
+// set modifies the link addresses of an entry.
+func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) {
+ addr := toAddress(i)
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if entry, ok := s.entriesMap[addr]; ok {
+ entry.LinkAddr = linkAddr
+ s.entriesMap[addr] = entry
+ }
+}
+
+// testNeighborResolver implements LinkAddressResolver to emulate sending a
+// neighbor probe.
+type testNeighborResolver struct {
+ clock tcpip.Clock
+ neigh *neighborCache
+ entries *testEntryStore
+ delay time.Duration
+ onLinkAddressRequest func()
+}
+
+var _ LinkAddressResolver = (*testNeighborResolver)(nil)
+
+func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+ // Delay handling the request to emulate network latency.
+ r.clock.AfterFunc(r.delay, func() {
+ r.fakeRequest(addr)
+ })
+
+ // Execute post address resolution action, if available.
+ if f := r.onLinkAddressRequest; f != nil {
+ f()
+ }
+ return nil
+}
+
+// fakeRequest emulates handling a response for a link address request.
+func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) {
+ if entry, ok := r.entries.entryByAddr(addr); ok {
+ r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ }
+}
+
+func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == testEntryBroadcastAddr {
+ return testEntryBroadcastLinkAddr, true
+ }
+ return "", false
+}
+
+func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return 0
+}
+
+type entryEvent struct {
+ nicID tcpip.NICID
+ address tcpip.Address
+ linkAddr tcpip.LinkAddress
+ state NeighborState
+}
+
+func TestNeighborCacheGetConfig(t *testing.T) {
+ nudDisp := testNUDDispatcher{}
+ c := DefaultNUDConfigurations()
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+
+ if got, want := neigh.config(), c; got != want {
+ t.Errorf("got neigh.config() = %+v, want = %+v", got, want)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheSetConfig(t *testing.T) {
+ nudDisp := testNUDDispatcher{}
+ c := DefaultNUDConfigurations()
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+ neigh.setConfig(c)
+
+ if got, want := neigh.config(), c; got != want {
+ t.Errorf("got neigh.config() = %+v, want = %+v", got, want)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheEntry(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ nudDisp := testNUDDispatcher{}
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+
+ clock.Advance(typicalLatency)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+
+ // No more events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheRemoveEntry(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+
+ clock.Advance(typicalLatency)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ neigh.removeEntry(entry.Addr)
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+}
+
+type testContext struct {
+ clock *faketime.ManualClock
+ neigh *neighborCache
+ store *testEntryStore
+ linkRes *testNeighborResolver
+ nudDisp *testNUDDispatcher
+}
+
+func newTestContext(c NUDConfigurations) testContext {
+ nudDisp := &testNUDDispatcher{}
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(nudDisp, c, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ return testContext{
+ clock: clock,
+ neigh: neigh,
+ store: store,
+ linkRes: linkRes,
+ nudDisp: nudDisp,
+ }
+}
+
+type overflowOptions struct {
+ startAtEntryIndex int
+ wantStaticEntries []NeighborEntry
+}
+
+func (c *testContext) overflowCache(opts overflowOptions) error {
+ // Fill the neighbor cache to capacity to verify the LRU eviction strategy is
+ // working properly after the entry removal.
+ for i := opts.startAtEntryIndex; i < c.store.size(); i++ {
+ // Add a new entry
+ entry, ok := c.store.entry(i)
+ if !ok {
+ return fmt.Errorf("c.store.entry(%d) not found", i)
+ }
+ if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.Advance(c.neigh.config().RetransmitTimer)
+
+ var wantEvents []testEntryEventInfo
+
+ // When beyond the full capacity, the cache will evict an entry as per the
+ // LRU eviction strategy. Note that the number of static entries should not
+ // affect the total number of dynamic entries that can be added.
+ if i >= neighborCacheSize+opts.startAtEntryIndex {
+ removedEntry, ok := c.store.entry(i - neighborCacheSize)
+ if !ok {
+ return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize)
+ }
+ wantEvents = append(wantEvents, testEntryEventInfo{
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ })
+ }
+
+ wantEvents = append(wantEvents, testEntryEventInfo{
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ }, testEntryEventInfo{
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ })
+
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Expect to find only the most recent entries. The order of entries reported
+ // by entries() is undeterministic, so entries have to be sorted before
+ // comparison.
+ wantUnsortedEntries := opts.wantStaticEntries
+ for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ {
+ entry, ok := c.store.entry(i)
+ if !ok {
+ return fmt.Errorf("c.store.entry(%d) not found", i)
+ }
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ }
+
+ if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
+ return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No more events should have been dispatched.
+ c.nudDisp.mu.Lock()
+ defer c.nudDisp.mu.Unlock()
+ if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ return nil
+}
+
+// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy
+// respects the dynamic entry count.
+func TestNeighborCacheOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache
+// eviction strategy respects the dynamic entry count when an entry is removed.
+func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a dynamic entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.Advance(c.neigh.config().RetransmitTimer)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Remove the entry
+ c.neigh.removeEntry(entry.Addr)
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that
+// adding a duplicate static entry with the same link address does not dispatch
+// any events.
+func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ c := newTestContext(config)
+
+ // Add a static entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Remove the static entry that was just added
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+
+ // No more events should have been dispatched.
+ c.nudDisp.mu.Lock()
+ defer c.nudDisp.mu.Unlock()
+ if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that
+// adding a duplicate static entry with a different link address dispatches a
+// change event.
+func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ c := newTestContext(config)
+
+ // Add a static entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Add a duplicate entry with a different link address
+ staticLinkAddr += "duplicate"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ defer c.nudDisp.mu.Unlock()
+ if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+}
+
+// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache
+// eviction strategy respects the dynamic entry count when a static entry is
+// added then removed. In this case, the dynamic entry count shouldn't have
+// been touched.
+func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a static entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Remove the static entry that was just added
+ c.neigh.removeEntry(entry.Addr)
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU
+// cache eviction strategy keeps count of the dynamic entry count when an entry
+// is overwritten by a static entry. Static entries should not count towards
+// the size of the LRU cache.
+func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a dynamic entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.Advance(typicalLatency)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Override the entry with a static one using the same address
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 1,
+ wantStaticEntries: []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LocalAddr: "", // static entries don't need a local address
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ },
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+func TestNeighborCacheNotifiesWaker(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ const wakerID = 1
+ s.AddWaker(&w, wakerID)
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ if doneCh == nil {
+ t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ }
+ clock.Advance(typicalLatency)
+
+ select {
+ case <-doneCh:
+ default:
+ t.Fatal("expected notification from done channel")
+ }
+
+ id, ok := s.Fetch(false /* block */)
+ if !ok {
+ t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ }
+ if id != wakerID {
+ t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheRemoveWaker(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ const wakerID = 1
+ s.AddWaker(&w, wakerID)
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ if doneCh == nil {
+ t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ }
+
+ // Remove the waker before the neighbor cache has the opportunity to send a
+ // notification.
+ neigh.removeWaker(entry.Addr, &w)
+ clock.Advance(typicalLatency)
+
+ select {
+ case <-doneCh:
+ default:
+ t.Fatal("expected notification from done channel")
+ }
+
+ if id, ok := s.Fetch(false /* block */); ok {
+ t.Errorf("unexpected notification from waker with id %d", id)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
+ e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err)
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: "", // static entries don't need a local address
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 1,
+ wantStaticEntries: []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LocalAddr: "", // static entries don't need a local address
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ },
+ },
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+func TestNeighborCacheClear(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ // Add a dynamic entry.
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.Advance(typicalLatency)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Add a static entry.
+ neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1)
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Clear shoud remove both dynamic and static entries.
+ neigh.clear()
+
+ // Remove events dispatched from clear() have no deterministic order so they
+ // need to be sorted beforehand.
+ wantUnsortedEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
+ }
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction
+// strategy keeps count of the dynamic entry count when all entries are
+// cleared.
+func TestNeighborCacheClearThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a dynamic entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.Advance(typicalLatency)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Clear the cache.
+ c.neigh.clear()
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ nudDisp := testNUDDispatcher{}
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ frequentlyUsedEntry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+
+ // The following logic is very similar to overflowCache, but
+ // periodically refreshes the frequently used entry.
+
+ // Fill the neighbor cache to capacity
+ for i := 0; i < neighborCacheSize; i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i)
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.Advance(typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Keep adding more entries
+ for i := neighborCacheSize; i < store.size(); i++ {
+ // Periodically refresh the frequently used entry
+ if i%(neighborCacheSize/2) == 0 {
+ _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err)
+ }
+ }
+
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i)
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.Advance(typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+
+ // An entry should have been removed, as per the LRU eviction strategy
+ removedEntry, ok := store.entry(i - neighborCacheSize + 1)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1)
+ }
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Expect to find only the frequently used entry and the most recent entries.
+ // The order of entries reported by entries() is undeterministic, so entries
+ // have to be sorted before comparison.
+ wantUnsortedEntries := []NeighborEntry{
+ {
+ Addr: frequentlyUsedEntry.Addr,
+ LocalAddr: frequentlyUsedEntry.LocalAddr,
+ LinkAddr: frequentlyUsedEntry.LinkAddr,
+ State: Reachable,
+ },
+ }
+
+ for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i)
+ }
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ }
+
+ if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
+ t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No more events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheConcurrent(t *testing.T) {
+ const concurrentProcesses = 16
+
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ storeEntries := store.entries()
+ for _, entry := range storeEntries {
+ var wg sync.WaitGroup
+ for r := 0; r < concurrentProcesses; r++ {
+ wg.Add(1)
+ go func(entry NeighborEntry) {
+ defer wg.Done()
+ e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil && err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock)
+ }
+ }(entry)
+ }
+
+ // Wait for all gorountines to send a request
+ wg.Wait()
+
+ // Process all the requests for a single entry concurrently
+ clock.Advance(typicalLatency)
+ }
+
+ // All goroutines add in the same order and add more values than can fit in
+ // the cache. Our eviction strategy requires that the last entries are
+ // present, up to the size of the neighbor cache, and the rest are missing.
+ // The order of entries reported by entries() is undeterministic, so entries
+ // have to be sorted before comparison.
+ var wantUnsortedEntries []NeighborEntry
+ for i := store.size() - neighborCacheSize; i < store.size(); i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Errorf("store.entry(%d) not found", i)
+ }
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ }
+
+ if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
+ t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheReplace(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ // Add an entry
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.Advance(typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+
+ // Verify the entry exists
+ e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ if doneCh != nil {
+ t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ }
+
+ // Notify of a link address change
+ var updatedLinkAddr tcpip.LinkAddress
+ {
+ entry, ok := store.entry(1)
+ if !ok {
+ t.Fatalf("store.entry(1) not found")
+ }
+ updatedLinkAddr = entry.LinkAddr
+ }
+ store.set(0, updatedLinkAddr)
+ neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+
+ // Requesting the entry again should start address resolution
+ {
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.Advance(config.DelayFirstProbeTime + typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+ }
+
+ // Verify the entry's new link address
+ {
+ e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ clock.Advance(typicalLatency)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ want = NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: updatedLinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ }
+ }
+}
+
+func TestNeighborCacheResolutionFailed(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+
+ var requestCount uint32
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ onLinkAddressRequest: func() {
+ atomic.AddUint32(&requestCount, 1)
+ },
+ }
+
+ // First, sanity check that resolution is working
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.Advance(typicalLatency)
+ got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ }
+
+ // Verify that address resolution for an unknown address returns ErrNoLinkAddress
+ before := atomic.LoadUint32(&requestCount)
+
+ entry.Addr += "2"
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
+ clock.Advance(waitFor)
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ }
+
+ maxAttempts := neigh.config().MaxUnicastProbes
+ if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want {
+ t.Errorf("got link address request count = %d, want = %d", got, want)
+ }
+}
+
+// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes
+// probes and not retrieving a confirmation before the duration defined by
+// MaxMulticastProbes * RetransmitTimer.
+func TestNeighborCacheResolutionTimeout(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ config.RetransmitTimer = time.Millisecond // small enough to cause timeout
+
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: time.Minute, // large enough to cause timeout
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
+ clock.Advance(waitFor)
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ }
+}
+
+// TestNeighborCacheStaticResolution checks that static link addresses are
+// resolved immediately and don't send resolution requests.
+func TestNeighborCacheStaticResolution(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil)
+ if err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err)
+ }
+ want := NeighborEntry{
+ Addr: testEntryBroadcastAddr,
+ LocalAddr: testEntryLocalAddr,
+ LinkAddr: testEntryBroadcastLinkAddr,
+ State: Static,
+ }
+ if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff)
+ }
+}
+
+func BenchmarkCacheClear(b *testing.B) {
+ b.StopTimer()
+ config := DefaultNUDConfigurations()
+ clock := &tcpip.StdClock{}
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: 0,
+ }
+
+ // Clear for every possible size of the cache
+ for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ {
+ // Fill the neighbor cache to capacity.
+ for i := 0; i < cacheSize; i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ b.Fatalf("store.entry(%d) not found", i)
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ if doneCh != nil {
+ <-doneCh
+ }
+ }
+
+ b.StartTimer()
+ neigh.clear()
+ b.StopTimer()
+ }
+}
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
new file mode 100644
index 000000000..9a72bec79
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -0,0 +1,490 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// NeighborEntry describes a neighboring device in the local network.
+type NeighborEntry struct {
+ Addr tcpip.Address
+ LocalAddr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAt time.Time
+}
+
+// NeighborState defines the state of a NeighborEntry within the Neighbor
+// Unreachability Detection state machine, as per RFC 4861 section 7.3.2.
+type NeighborState uint8
+
+const (
+ // Unknown means reachability has not been verified yet. This is the initial
+ // state of entries that have been created automatically by the Neighbor
+ // Unreachability Detection state machine.
+ Unknown NeighborState = iota
+ // Incomplete means that there is an outstanding request to resolve the
+ // address.
+ Incomplete
+ // Reachable means the path to the neighbor is functioning properly for both
+ // receive and transmit paths.
+ Reachable
+ // Stale means reachability to the neighbor is unknown, but packets are still
+ // able to be transmitted to the possibly stale link address.
+ Stale
+ // Delay means reachability to the neighbor is unknown and pending
+ // confirmation from an upper-level protocol like TCP, but packets are still
+ // able to be transmitted to the possibly stale link address.
+ Delay
+ // Probe means a reachability confirmation is actively being sought by
+ // periodically retransmitting reachability probes until a reachability
+ // confirmation is received, or until the max amount of probes has been sent.
+ Probe
+ // Static describes entries that have been explicitly added by the user. They
+ // do not expire and are not deleted until explicitly removed.
+ Static
+ // Failed means traffic should not be sent to this neighbor since attempts of
+ // reachability have returned inconclusive.
+ Failed
+)
+
+// neighborEntry implements a neighbor entry's individual node behavior, as per
+// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in
+// parallel with the sending of packets to a neighbor, necessitating the
+// entry's lock to be acquired for all operations.
+type neighborEntry struct {
+ neighborEntryEntry
+
+ nic *NIC
+
+ // linkRes provides the functionality to send reachability probes, used in
+ // Neighbor Unreachability Detection.
+ linkRes LinkAddressResolver
+
+ // nudState points to the Neighbor Unreachability Detection configuration.
+ nudState *NUDState
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ neigh NeighborEntry
+
+ // wakers is a set of waiters for address resolution result. Anytime state
+ // transitions out of incomplete these waiters are notified. It is nil iff
+ // address resolution is ongoing and no clients are waiting for the result.
+ wakers map[*sleep.Waker]struct{}
+
+ // done is used to allow callers to wait on address resolution. It is nil
+ // iff nudState is not Reachable and address resolution is not yet in
+ // progress.
+ done chan struct{}
+
+ isRouter bool
+ job *tcpip.Job
+}
+
+// newNeighborEntry creates a neighbor cache entry starting at the default
+// state, Unknown. Transition out of Unknown by calling either
+// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created
+// neighborEntry.
+func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
+ return &neighborEntry{
+ nic: nic,
+ linkRes: linkRes,
+ nudState: nudState,
+ neigh: NeighborEntry{
+ Addr: remoteAddr,
+ LocalAddr: localAddr,
+ State: Unknown,
+ },
+ }
+}
+
+// newStaticNeighborEntry creates a neighbor cache entry starting at the Static
+// state. The entry can only transition out of Static by directly calling
+// `setStateLocked`.
+func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
+ if nic.stack.nudDisp != nil {
+ nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now())
+ }
+ return &neighborEntry{
+ nic: nic,
+ nudState: state,
+ neigh: NeighborEntry{
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAt: time.Now(),
+ },
+ }
+}
+
+// addWaker adds w to the list of wakers waiting for address resolution.
+// Assumes the entry has already been appropriately locked.
+func (e *neighborEntry) addWakerLocked(w *sleep.Waker) {
+ if w == nil {
+ return
+ }
+ if e.wakers == nil {
+ e.wakers = make(map[*sleep.Waker]struct{})
+ }
+ e.wakers[w] = struct{}{}
+}
+
+// notifyWakersLocked notifies those waiting for address resolution, whether it
+// succeeded or failed. Assumes the entry has already been appropriately locked.
+func (e *neighborEntry) notifyWakersLocked() {
+ for w := range e.wakers {
+ w.Assert()
+ }
+ e.wakers = nil
+ if ch := e.done; ch != nil {
+ close(ch)
+ e.done = nil
+ }
+}
+
+// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
+// been added.
+func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) {
+ if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ }
+}
+
+// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
+// has changed state or link-layer address.
+func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) {
+ if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ }
+}
+
+// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry
+// has been removed.
+func (e *neighborEntry) dispatchRemoveEventLocked() {
+ if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now())
+ }
+}
+
+// setStateLocked transitions the entry to the specified state immediately.
+//
+// Follows the logic defined in RFC 4861 section 7.3.3.
+//
+// e.mu MUST be locked.
+func (e *neighborEntry) setStateLocked(next NeighborState) {
+ // Cancel the previously scheduled action, if there is one. Entries in
+ // Unknown, Stale, or Static state do not have scheduled actions.
+ if timer := e.job; timer != nil {
+ timer.Cancel()
+ }
+
+ prev := e.neigh.State
+ e.neigh.State = next
+ e.neigh.UpdatedAt = time.Now()
+ config := e.nudState.Config()
+
+ switch next {
+ case Incomplete:
+ var retryCounter uint32
+ var sendMulticastProbe func()
+
+ sendMulticastProbe = func() {
+ if retryCounter == config.MaxMulticastProbes {
+ // "If no Neighbor Advertisement is received after
+ // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
+ // The sender MUST return ICMP destination unreachable indications with
+ // code 3 (Address Unreachable) for each packet queued awaiting address
+ // resolution." - RFC 4861 section 7.2.2
+ //
+ // There is no need to send an ICMP destination unreachable indication
+ // since the failure to resolve the address is expected to only occur
+ // on this node. Thus, redirecting traffic is currently not supported.
+ //
+ // "If the error occurs on a node other than the node originating the
+ // packet, an ICMP error message is generated. If the error occurs on
+ // the originating node, an implementation is not required to actually
+ // create and send an ICMP error packet to the source, as long as the
+ // upper-layer sender is notified through an appropriate mechanism
+ // (e.g. return value from a procedure call). Note, however, that an
+ // implementation may find it convenient in some cases to return errors
+ // to the sender by taking the offending packet, generating an ICMP
+ // error message, and then delivering it (locally) through the generic
+ // error-handling routines.' - RFC 4861 section 2.1
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil {
+ // There is no need to log the error here; the NUD implementation may
+ // assume a working link. A valid link should be the responsibility of
+ // the NIC/stack.LinkEndpoint.
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ retryCounter++
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(config.RetransmitTimer)
+ }
+
+ sendMulticastProbe()
+
+ case Reachable:
+ e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ })
+ e.job.Schedule(e.nudState.ReachableTime())
+
+ case Delay:
+ e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.dispatchChangeEventLocked(Probe)
+ e.setStateLocked(Probe)
+ })
+ e.job.Schedule(config.DelayFirstProbeTime)
+
+ case Probe:
+ var retryCounter uint32
+ var sendUnicastProbe func()
+
+ sendUnicastProbe = func() {
+ if retryCounter == config.MaxUnicastProbes {
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil {
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ retryCounter++
+ if retryCounter == config.MaxUnicastProbes {
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
+ e.job.Schedule(config.RetransmitTimer)
+ }
+
+ sendUnicastProbe()
+
+ case Failed:
+ e.notifyWakersLocked()
+ e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.nic.neigh.removeEntryLocked(e)
+ })
+ e.job.Schedule(config.UnreachableTime)
+
+ case Unknown, Stale, Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid state transition from %q to %q", prev, next))
+ }
+}
+
+// handlePacketQueuedLocked advances the state machine according to a packet
+// being queued for outgoing transmission.
+//
+// Follows the logic defined in RFC 4861 section 7.3.3.
+func (e *neighborEntry) handlePacketQueuedLocked() {
+ switch e.neigh.State {
+ case Unknown:
+ e.dispatchAddEventLocked(Incomplete)
+ e.setStateLocked(Incomplete)
+
+ case Stale:
+ e.dispatchChangeEventLocked(Delay)
+ e.setStateLocked(Delay)
+
+ case Incomplete, Reachable, Delay, Probe, Static, Failed:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
+
+// handleProbeLocked processes an incoming neighbor probe (e.g. ARP request or
+// Neighbor Solicitation for ARP or NDP, respectively).
+//
+// Follows the logic defined in RFC 4861 section 7.2.3.
+func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
+ // Probes MUST be silently discarded if the target address is tentative, does
+ // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These
+ // checks MUST be done by the NetworkEndpoint.
+
+ switch e.neigh.State {
+ case Unknown, Incomplete, Failed:
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.dispatchAddEventLocked(Stale)
+ e.setStateLocked(Stale)
+ e.notifyWakersLocked()
+
+ case Reachable, Delay, Probe:
+ if e.neigh.LinkAddr != remoteLinkAddr {
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ }
+
+ case Stale:
+ if e.neigh.LinkAddr != remoteLinkAddr {
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.dispatchChangeEventLocked(Stale)
+ }
+
+ case Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
+
+// handleConfirmationLocked processes an incoming neighbor confirmation
+// (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, respectively).
+//
+// Follows the state machine defined by RFC 4861 section 7.2.5.
+//
+// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other
+// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol
+// should be deployed where preventing access to the broadcast segment might
+// not be possible. SEND uses RSA key pairs to produce Cryptographically
+// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the
+// claimed source of an NDP message is the owner of the claimed address.
+func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
+ switch e.neigh.State {
+ case Incomplete:
+ if len(linkAddr) == 0 {
+ // "If the link layer has addresses and no Target Link-Layer Address
+ // option is included, the receiving node SHOULD silently discard the
+ // received advertisement." - RFC 4861 section 7.2.5
+ break
+ }
+
+ e.neigh.LinkAddr = linkAddr
+ if flags.Solicited {
+ e.dispatchChangeEventLocked(Reachable)
+ e.setStateLocked(Reachable)
+ } else {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ }
+ e.isRouter = flags.IsRouter
+ e.notifyWakersLocked()
+
+ // "Note that the Override flag is ignored if the entry is in the
+ // INCOMPLETE state." - RFC 4861 section 7.2.5
+
+ case Reachable, Stale, Delay, Probe:
+ sameLinkAddr := e.neigh.LinkAddr == linkAddr
+
+ if !sameLinkAddr {
+ if !flags.Override {
+ if e.neigh.State == Reachable {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ }
+ break
+ }
+
+ e.neigh.LinkAddr = linkAddr
+
+ if !flags.Solicited {
+ if e.neigh.State != Stale {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ } else {
+ // Notify the LinkAddr change, even though NUD state hasn't changed.
+ e.dispatchChangeEventLocked(e.neigh.State)
+ }
+ break
+ }
+ }
+
+ if flags.Solicited && (flags.Override || sameLinkAddr) {
+ if e.neigh.State != Reachable {
+ e.dispatchChangeEventLocked(Reachable)
+ }
+ // Set state to Reachable again to refresh timers.
+ e.setStateLocked(Reachable)
+ e.notifyWakersLocked()
+ }
+
+ if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) {
+ // "In those cases where the IsRouter flag changes from TRUE to FALSE as
+ // a result of this update, the node MUST remove that router from the
+ // Default Router List and update the Destination Cache entries for all
+ // destinations using that neighbor as a router as specified in Section
+ // 7.3.3. This is needed to detect when a node that is used as a router
+ // stops forwarding packets due to being configured as a host."
+ // - RFC 4861 section 7.2.5
+ //
+ // TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6
+ // here.
+ ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber]
+ if !ok {
+ panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint"))
+ }
+
+ if ndpEP, ok := ep.(NDPEndpoint); ok {
+ ndpEP.InvalidateDefaultRouter(e.neigh.Addr)
+ }
+ }
+ e.isRouter = flags.IsRouter
+
+ case Unknown, Failed, Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
+
+// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol
+// (e.g. TCP acknowledgements) reachability confirmation.
+func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
+ switch e.neigh.State {
+ case Reachable, Stale, Delay, Probe:
+ if e.neigh.State != Reachable {
+ e.dispatchChangeEventLocked(Reachable)
+ // Set state to Reachable again to refresh timers.
+ }
+ e.setStateLocked(Reachable)
+
+ case Unknown, Incomplete, Failed, Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
new file mode 100644
index 000000000..a265fff0a
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -0,0 +1,2869 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "math"
+ "math/rand"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+
+ entryTestNICID tcpip.NICID = 1
+ entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+
+ entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01")
+ entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02")
+
+ // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
+ // except where another value is explicitly used. It is chosen to match the
+ // MTU of loopback interfaces on Linux systems.
+ entryTestNetDefaultMTU = 65536
+)
+
+// eventDiffOpts are the options passed to cmp.Diff to compare entry events.
+// The UpdatedAt field is ignored due to a lack of a deterministic method to
+// predict the time that an event will be dispatched.
+func eventDiffOpts() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
+ }
+}
+
+// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to
+// sort slices of events for cases where ordering must be ignored.
+func eventDiffOptsWithSort() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
+ cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
+ return strings.Compare(string(a.Addr), string(b.Addr)) < 0
+ }),
+ }
+}
+
+// The following unit tests exercise every state transition and verify its
+// behavior with RFC 4681.
+//
+// | From | To | Cause | Action | Event |
+// | ========== | ========== | ========================================== | =============== | ======= |
+// | Unknown | Unknown | Confirmation w/ unknown address | | Added |
+// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added |
+// | Unknown | Stale | Probe w/ unknown address | | Added |
+// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed |
+// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed |
+// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed |
+// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed |
+// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | |
+// | Reachable | Stale | Reachable timer expired | | Changed |
+// | Reachable | Stale | Probe or confirmation w/ different address | | Changed |
+// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
+// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
+// | Stale | Delay | Packet sent | | Changed |
+// | Delay | Reachable | Upper-layer confirmation | | Changed |
+// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Delay | Stale | Probe or confirmation w/ different address | | Changed |
+// | Delay | Probe | Delay timer expired | Send probe | Changed |
+// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed |
+// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
+// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
+// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
+// | Failed | | Unreachability timer expired | Delete entry | |
+
+type testEntryEventType uint8
+
+const (
+ entryTestAdded testEntryEventType = iota
+ entryTestChanged
+ entryTestRemoved
+)
+
+func (t testEntryEventType) String() string {
+ switch t {
+ case entryTestAdded:
+ return "add"
+ case entryTestChanged:
+ return "change"
+ case entryTestRemoved:
+ return "remove"
+ default:
+ return fmt.Sprintf("unknown (%d)", t)
+ }
+}
+
+// Fields are exported for use with cmp.Diff.
+type testEntryEventInfo struct {
+ EventType testEntryEventType
+ NICID tcpip.NICID
+ Addr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAt time.Time
+}
+
+func (e testEntryEventInfo) String() string {
+ return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State)
+}
+
+// testNUDDispatcher implements NUDDispatcher to validate the dispatching of
+// events upon certain NUD state machine events.
+type testNUDDispatcher struct {
+ mu sync.Mutex
+ events []testEntryEventInfo
+}
+
+var _ NUDDispatcher = (*testNUDDispatcher)(nil)
+
+func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.events = append(d.events, e)
+}
+
+func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+ d.queueEvent(testEntryEventInfo{
+ EventType: entryTestAdded,
+ NICID: nicID,
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: state,
+ UpdatedAt: updatedAt,
+ })
+}
+
+func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+ d.queueEvent(testEntryEventInfo{
+ EventType: entryTestChanged,
+ NICID: nicID,
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: state,
+ UpdatedAt: updatedAt,
+ })
+}
+
+func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+ d.queueEvent(testEntryEventInfo{
+ EventType: entryTestRemoved,
+ NICID: nicID,
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: state,
+ UpdatedAt: updatedAt,
+ })
+}
+
+type entryTestLinkResolver struct {
+ mu sync.Mutex
+ probes []entryTestProbeInfo
+}
+
+var _ LinkAddressResolver = (*entryTestLinkResolver)(nil)
+
+type entryTestProbeInfo struct {
+ RemoteAddress tcpip.Address
+ RemoteLinkAddress tcpip.LinkAddress
+ LocalAddress tcpip.Address
+}
+
+func (p entryTestProbeInfo) String() string {
+ return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress)
+}
+
+// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
+// to the local network if linkAddr is the zero value.
+func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+ p := entryTestProbeInfo{
+ RemoteAddress: addr,
+ RemoteLinkAddress: linkAddr,
+ LocalAddress: localAddr,
+ }
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.probes = append(r.probes, p)
+ return nil
+}
+
+// ResolveStaticAddress attempts to resolve address without sending requests.
+// It either resolves the name immediately or returns the empty LinkAddress.
+func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ return "", false
+}
+
+// LinkAddressProtocol returns the network protocol of the addresses this
+// resolver can resolve.
+func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return entryTestNetNumber
+}
+
+func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) {
+ clock := faketime.NewManualClock()
+ disp := testNUDDispatcher{}
+ nic := NIC{
+ id: entryTestNICID,
+ linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+ stack: &Stack{
+ clock: clock,
+ nudDisp: &disp,
+ },
+ }
+ nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
+ header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil),
+ }
+
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
+ nudState := NewNUDState(c, rng)
+ linkRes := entryTestLinkResolver{}
+ entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes)
+
+ // Stub out the neighbor cache to verify deletion from the cache.
+ nic.neigh = &neighborCache{
+ nic: &nic,
+ state: nudState,
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+ nic.neigh.cache[entryTestAddr1] = entry
+
+ return entry, &disp, &linkRes, clock
+}
+
+// TestEntryInitiallyUnknown verifies that the state of a newly created
+// neighborEntry is Unknown.
+func TestEntryInitiallyUnknown(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Unknown; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.Advance(c.RetransmitTimer)
+
+ // No probes should have been sent.
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Unknown; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.Advance(time.Hour)
+
+ // No probes should have been sent.
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryUnknownToIncomplete(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ }
+ {
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+}
+
+func TestEntryUnknownToStale(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ // No probes should have been sent.
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ updatedAt := e.neigh.UpdatedAt
+ e.mu.Unlock()
+
+ clock.Advance(c.RetransmitTimer)
+
+ // UpdatedAt should remain the same during address resolution.
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.UpdatedAt, updatedAt; got != want {
+ t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.Advance(c.RetransmitTimer)
+
+ // UpdatedAt should change after failing address resolution. Timing out after
+ // sending the last probe transitions the entry to Failed.
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ clock.Advance(c.RetransmitTimer)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant {
+ t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryIncompleteToReachable(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+// TestEntryAddsAndClearsWakers verifies that wakers are added when
+// addWakerLocked is called and cleared when address resolution finishes. In
+// this case, address resolution will finish when transitioning from Incomplete
+// to Reachable.
+func TestEntryAddsAndClearsWakers(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ s.AddWaker(&w, 123)
+ defer s.Done()
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got := e.wakers; got != nil {
+ t.Errorf("got e.wakers = %v, want = nil", got)
+ }
+ e.addWakerLocked(&w)
+ if got, want := w.IsAsserted(), false; got != want {
+ t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
+ }
+ if e.wakers == nil {
+ t.Error("expected e.wakers to be non-nil")
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.wakers != nil {
+ t.Errorf("got e.wakers = %v, want = nil", e.wakers)
+ }
+ if got, want := w.IsAsserted(), true; got != want {
+ t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.isRouter, true; got != want {
+ t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ linkRes.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToStale(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToFailed(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes)
+ clock.Advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The Incomplete-to-Incomplete state transition is tested here by
+ // verifying that 3 reachability probes were sent.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Failed; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+type testLocker struct{}
+
+var _ sync.Locker = (*testLocker)(nil)
+
+func (*testLocker) Lock() {}
+func (*testLocker) Unlock() {}
+
+func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.isRouter, true; got != want {
+ t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ }
+
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.isRouter, false; got != want {
+ t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ }
+ if ipv6EP.invalidatedRtr != e.neigh.Addr {
+ t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.neigh.Addr)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.Advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToDelay(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleUpperLevelConfirmationLocked()
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.Advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 1
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.Advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToProbe(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.Advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.Advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.Advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.Advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+// TestEntryUnknownToStaleToProbeToReachable exercises the following scenario:
+// 1. Probe is received
+// 2. Entry is created in Stale
+// 3. Packet is queued on the entry
+// 4. Entry transitions to Delay then Probe
+// 5. Probe is sent
+func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Probe to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handleProbeLocked(entryTestLinkAddr1)
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.Advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // Probe caused by the Delay-to-Probe transition
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.Advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.Advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.Advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.Advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.Advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryProbeToFailed(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ c.MaxUnicastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
+ clock.Advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The next three probe are caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Failed; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryFailedGetsDeleted(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ c.MaxUnicastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ // Verify the cache contains the entry.
+ if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
+ t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
+ }
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
+ clock.Advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The next three probe are caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ // Verify the cache no longer contains the entry.
+ if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok {
+ t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1)
+ }
+}
diff --git a/pkg/tcpip/stack/neighborstate_string.go b/pkg/tcpip/stack/neighborstate_string.go
new file mode 100644
index 000000000..aa7311ec6
--- /dev/null
+++ b/pkg/tcpip/stack/neighborstate_string.go
@@ -0,0 +1,44 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type NeighborState"; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[Unknown-0]
+ _ = x[Incomplete-1]
+ _ = x[Reachable-2]
+ _ = x[Stale-3]
+ _ = x[Delay-4]
+ _ = x[Probe-5]
+ _ = x[Static-6]
+ _ = x[Failed-7]
+}
+
+const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailed"
+
+var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53}
+
+func (i NeighborState) String() string {
+ if i >= NeighborState(len(_NeighborState_index)-1) {
+ return "NeighborState(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _NeighborState_name[_NeighborState_index[i]:_NeighborState_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index afb7dfeaf..6cf54cc89 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -16,24 +16,18 @@ package stack
import (
"fmt"
+ "math/rand"
"reflect"
- "sort"
- "strings"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-var ipv4BroadcastAddr = tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: header.IPv4Broadcast,
- PrefixLen: 8 * header.IPv4AddressSize,
- },
-}
+var _ NetworkInterface = (*NIC)(nil)
// NIC represents a "network interface card" to which the networking stack is
// attached.
@@ -45,20 +39,24 @@ type NIC struct {
context NICContext
stats NICStats
+ neigh *neighborCache
+
+ // The network endpoints themselves may be modified by calling the interface's
+ // methods, but the map reference and entries must be constant.
+ networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
+
+ // enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
mu struct {
sync.RWMutex
- enabled bool
- spoofing bool
- promiscuous bool
- primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
- endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
- addressRanges []tcpip.Subnet
- mcastJoins map[NetworkEndpointID]uint32
+ spoofing bool
+ promiscuous bool
// packetEPs is protected by mu, but the contained PacketEndpoint
// values are not.
packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
- ndp ndpState
}
}
@@ -82,25 +80,6 @@ type DirectionStats struct {
Bytes *tcpip.StatCounter
}
-// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior.
-type PrimaryEndpointBehavior int
-
-const (
- // CanBePrimaryEndpoint indicates the endpoint can be used as a primary
- // endpoint for new connections with no local address. This is the
- // default when calling NIC.AddAddress.
- CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
-
- // FirstPrimaryEndpoint indicates the endpoint should be the first
- // primary endpoint considered. If there are multiple endpoints with
- // this behavior, the most recently-added one will be first.
- FirstPrimaryEndpoint
-
- // NeverPrimaryEndpoint indicates the endpoint should never be a
- // primary endpoint.
- NeverPrimaryEndpoint
-)
-
// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
@@ -112,33 +91,43 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// of IPv6 is supported on this endpoint's LinkEndpoint.
nic := &NIC{
- stack: stack,
- id: id,
- name: name,
- linkEP: ep,
- context: ctx,
- stats: makeNICStats(),
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ context: ctx,
+ stats: makeNICStats(),
+ networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
- nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
- nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
- nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
- nic.mu.ndp = ndpState{
- nic: nic,
- configs: stack.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
+
+ // Check for Neighbor Unreachability Detection support.
+ var nud NUDHandler
+ if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 && stack.useNeighborCache {
+ rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds()))
+ nic.neigh = &neighborCache{
+ nic: nic,
+ state: NewNUDState(stack.nudConfigs, rng),
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+
+ // An interface value that holds a nil pointer but non-nil type is not the
+ // same as the nil interface. Because of this, nud must only be assignd if
+ // nic.neigh is non-nil since a nil reference to a neighborCache is not
+ // valid.
+ //
+ // See https://golang.org/doc/faq#nil_error for more information.
+ nud = nic.neigh
}
- nic.mu.ndp.initializeTempAddrState()
- // Register supported packet endpoint protocols.
+ // Register supported packet and network endpoint protocols.
for _, netProto := range header.Ethertypes {
nic.mu.packetEPs[netProto] = []PacketEndpoint{}
}
for _, netProto := range stack.networkProtocols {
- nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{}
+ netNum := netProto.Number()
+ nic.mu.packetEPs[netNum] = nil
+ nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
nic.linkEP.Attach(nic)
@@ -146,29 +135,32 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
return nic
}
-// enabled returns true if n is enabled.
-func (n *NIC) enabled() bool {
- n.mu.RLock()
- enabled := n.mu.enabled
- n.mu.RUnlock()
- return enabled
+func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint {
+ return n.networkEndpoints[proto]
}
-// disable disables n.
+// Enabled implements NetworkInterface.
+func (n *NIC) Enabled() bool {
+ return atomic.LoadUint32(&n.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the NIC.
//
-// It undoes the work done by enable.
-func (n *NIC) disable() *tcpip.Error {
- n.mu.RLock()
- enabled := n.mu.enabled
- n.mu.RUnlock()
- if !enabled {
- return nil
+// Returns true if the enabled status was updated.
+func (n *NIC) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&n.enabled, 1) == 0
}
+ return atomic.SwapUint32(&n.enabled, 0) == 1
+}
+// disable disables n.
+//
+// It undoes the work done by enable.
+func (n *NIC) disable() {
n.mu.Lock()
- err := n.disableLocked()
+ n.disableLocked()
n.mu.Unlock()
- return err
}
// disableLocked disables n.
@@ -176,43 +168,19 @@ func (n *NIC) disable() *tcpip.Error {
// It undoes the work done by enable.
//
// n MUST be locked.
-func (n *NIC) disableLocked() *tcpip.Error {
- if !n.mu.enabled {
- return nil
+func (n *NIC) disableLocked() {
+ if !n.setEnabled(false) {
+ return
}
- // TODO(b/147015577): Should Routes that are currently bound to n be
+ // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be
// invalidated? Currently, Routes will continue to work when a NIC is enabled
// again, and applications may not know that the underlying NIC was ever
// disabled.
- if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
- n.mu.ndp.stopSolicitingRouters()
- n.mu.ndp.cleanupState(false /* hostOnly */)
-
- // Stop DAD for all the unicast IPv6 endpoints that are in the
- // permanentTentative state.
- for _, r := range n.mu.endpoints {
- if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) {
- n.mu.ndp.stopDuplicateAddressDetection(addr)
- }
- }
-
- // The NIC may have already left the multicast group.
- if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
- return err
- }
- }
-
- if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- // The address may have already been removed.
- if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress {
- return err
- }
+ for _, ep := range n.networkEndpoints {
+ ep.Disable()
}
-
- n.mu.enabled = false
- return nil
}
// enable enables n.
@@ -222,150 +190,38 @@ func (n *NIC) disableLocked() *tcpip.Error {
// routers if the stack is not operating as a router. If the stack is also
// configured to auto-generate a link-local address, one will be generated.
func (n *NIC) enable() *tcpip.Error {
- n.mu.RLock()
- enabled := n.mu.enabled
- n.mu.RUnlock()
- if enabled {
- return nil
- }
-
n.mu.Lock()
defer n.mu.Unlock()
- if n.mu.enabled {
- return nil
- }
-
- n.mu.enabled = true
-
- // Create an endpoint to receive broadcast packets on this interface.
- if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
- return err
- }
- }
-
- // Join the IPv6 All-Nodes Multicast group if the stack is configured to
- // use IPv6. This is required to ensure that this node properly receives
- // and responds to the various NDP messages that are destined to the
- // all-nodes multicast address. An example is the Neighbor Advertisement
- // when we perform Duplicate Address Detection, or Router Advertisement
- // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
- // section 4.2 for more information.
- //
- // Also auto-generate an IPv6 link-local address based on the NIC's
- // link address if it is configured to do so. Note, each interface is
- // required to have IPv6 link-local unicast address, as per RFC 4291
- // section 2.1.
- _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]
- if !ok {
+ if !n.setEnabled(true) {
return nil
}
- // Join the All-Nodes multicast group before starting DAD as responses to DAD
- // (NDP NS) messages may be sent to the All-Nodes multicast group if the
- // source address of the NDP NS is the unspecified address, as per RFC 4861
- // section 7.2.4.
- if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
- return err
- }
-
- // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
- // state.
- //
- // Addresses may have aleady completed DAD but in the time since the NIC was
- // last enabled, other devices may have acquired the same addresses.
- for _, r := range n.mu.endpoints {
- addr := r.ep.ID().LocalAddress
- if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) {
- continue
- }
-
- r.setKind(permanentTentative)
- if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil {
+ for _, ep := range n.networkEndpoints {
+ if err := ep.Enable(); err != nil {
return err
}
}
- // Do not auto-generate an IPv6 link-local address for loopback devices.
- if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
- // The valid and preferred lifetime is infinite for the auto-generated
- // link-local address.
- n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
- }
-
- // If we are operating as a router, then do not solicit routers since we
- // won't process the RAs anyways.
- //
- // Routers do not process Router Advertisements (RA) the same way a host
- // does. That is, routers do not learn from RAs (e.g. on-link prefixes
- // and default routers). Therefore, soliciting RAs from other routers on
- // a link is unnecessary for routers.
- if !n.stack.forwarding {
- n.mu.ndp.startSolicitingRouters()
- }
-
return nil
}
-// remove detaches NIC from the link endpoint, and marks existing referenced
-// network endpoints expired. This guarantees no packets between this NIC and
-// the network stack.
+// remove detaches NIC from the link endpoint and releases network endpoint
+// resources. This guarantees no packets between this NIC and the network
+// stack.
func (n *NIC) remove() *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
n.disableLocked()
- // TODO(b/151378115): come up with a better way to pick an error than the
- // first one.
- var err *tcpip.Error
-
- // Forcefully leave multicast groups.
- for nid := range n.mu.mcastJoins {
- if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil {
- err = tempErr
- }
- }
-
- // Remove permanent and permanentTentative addresses, so no packet goes out.
- for nid, ref := range n.mu.endpoints {
- switch ref.getKind() {
- case permanentTentative, permanent:
- if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil {
- err = tempErr
- }
- }
+ for _, ep := range n.networkEndpoints {
+ ep.Close()
}
// Detach from link endpoint, so no packet comes in.
n.linkEP.Attach(nil)
-
- return err
-}
-
-// becomeIPv6Router transitions n into an IPv6 router.
-//
-// When transitioning into an IPv6 router, host-only state (NDP discovered
-// routers, discovered on-link prefixes, and auto-generated addresses) will
-// be cleaned up/invalidated and NDP router solicitations will be stopped.
-func (n *NIC) becomeIPv6Router() {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- n.mu.ndp.cleanupState(true /* hostOnly */)
- n.mu.ndp.stopSolicitingRouters()
-}
-
-// becomeIPv6Host transitions n into an IPv6 host.
-//
-// When transitioning into an IPv6 host, NDP router solicitations will be
-// started.
-func (n *NIC) becomeIPv6Host() {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- n.mu.ndp.startSolicitingRouters()
+ return nil
}
// setPromiscuousMode enables or disables promiscuous mode.
@@ -382,7 +238,8 @@ func (n *NIC) isPromiscuousMode() bool {
return rv
}
-func (n *NIC) isLoopback() bool {
+// IsLoopback implements NetworkInterface.
+func (n *NIC) IsLoopback() bool {
return n.linkEP.Capabilities()&CapabilityLoopback != 0
}
@@ -393,213 +250,53 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-// primaryEndpoint will return the first non-deprecated endpoint if such an
-// endpoint exists for the given protocol and remoteAddr. If no non-deprecated
-// endpoint exists, the first deprecated endpoint will be returned.
-//
-// If an IPv6 primary endpoint is requested, Source Address Selection (as
-// defined by RFC 6724 section 5) will be performed.
-func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint {
- if protocol == header.IPv6ProtocolNumber && remoteAddr != "" {
- return n.primaryIPv6Endpoint(remoteAddr)
- }
-
+// primaryAddress returns an address that can be used to communicate with
+// remoteAddr.
+func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
n.mu.RLock()
- defer n.mu.RUnlock()
-
- var deprecatedEndpoint *referencedNetworkEndpoint
- for _, r := range n.mu.primary[protocol] {
- if !r.isValidForOutgoingRLocked() {
- continue
- }
-
- if !r.deprecated {
- if r.tryIncRef() {
- // r is not deprecated, so return it immediately.
- //
- // If we kept track of a deprecated endpoint, decrement its reference
- // count since it was incremented when we decided to keep track of it.
- if deprecatedEndpoint != nil {
- deprecatedEndpoint.decRefLocked()
- deprecatedEndpoint = nil
- }
-
- return r
- }
- } else if deprecatedEndpoint == nil && r.tryIncRef() {
- // We prefer an endpoint that is not deprecated, but we keep track of r in
- // case n doesn't have any non-deprecated endpoints.
- //
- // If we end up finding a more preferred endpoint, r's reference count
- // will be decremented when such an endpoint is found.
- deprecatedEndpoint = r
- }
- }
-
- // n doesn't have any valid non-deprecated endpoints, so return
- // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated
- // endpoints either).
- return deprecatedEndpoint
-}
-
-// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC
-// 6724 section 5).
-type ipv6AddrCandidate struct {
- ref *referencedNetworkEndpoint
- scope header.IPv6AddressScope
-}
-
-// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address
-// Selection (RFC 6724 section 5).
-//
-// Note, only rules 1-3 and 7 are followed.
-//
-// remoteAddr must be a valid IPv6 address.
-func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
- n.mu.RLock()
- ref := n.primaryIPv6EndpointRLocked(remoteAddr)
+ spoofing := n.mu.spoofing
n.mu.RUnlock()
- return ref
-}
-
-// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address
-// Selection (RFC 6724 section 5).
-//
-// Note, only rules 1-3 and 7 are followed.
-//
-// remoteAddr must be a valid IPv6 address.
-//
-// n.mu MUST be read locked.
-func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
- primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber]
-
- if len(primaryAddrs) == 0 {
- return nil
- }
-
- // Create a candidate set of available addresses we can potentially use as a
- // source address.
- cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs))
- for _, r := range primaryAddrs {
- // If r is not valid for outgoing connections, it is not a valid endpoint.
- if !r.isValidForOutgoingRLocked() {
- continue
- }
-
- addr := r.ep.ID().LocalAddress
- scope, err := header.ScopeForIPv6Address(addr)
- if err != nil {
- // Should never happen as we got r from the primary IPv6 endpoint list and
- // ScopeForIPv6Address only returns an error if addr is not an IPv6
- // address.
- panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err))
- }
-
- cs = append(cs, ipv6AddrCandidate{
- ref: r,
- scope: scope,
- })
- }
-
- remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
- if err != nil {
- // primaryIPv6Endpoint should never be called with an invalid IPv6 address.
- panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
- }
-
- // Sort the addresses as per RFC 6724 section 5 rules 1-3.
- //
- // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
- sort.Slice(cs, func(i, j int) bool {
- sa := cs[i]
- sb := cs[j]
-
- // Prefer same address as per RFC 6724 section 5 rule 1.
- if sa.ref.ep.ID().LocalAddress == remoteAddr {
- return true
- }
- if sb.ref.ep.ID().LocalAddress == remoteAddr {
- return false
- }
-
- // Prefer appropriate scope as per RFC 6724 section 5 rule 2.
- if sa.scope < sb.scope {
- return sa.scope >= remoteScope
- } else if sb.scope < sa.scope {
- return sb.scope < remoteScope
- }
-
- // Avoid deprecated addresses as per RFC 6724 section 5 rule 3.
- if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep {
- // If sa is not deprecated, it is preferred over sb.
- return sbDep
- }
-
- // Prefer temporary addresses as per RFC 6724 section 5 rule 7.
- if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp {
- return saTemp
- }
-
- // sa and sb are equal, return the endpoint that is closest to the front of
- // the primary endpoint list.
- return i < j
- })
-
- // Return the most preferred address that can have its reference count
- // incremented.
- for _, c := range cs {
- if r := c.ref; r.tryIncRef() {
- return r
- }
- }
-
- return nil
-}
-
-// hasPermanentAddrLocked returns true if n has a permanent (including currently
-// tentative) address, addr.
-func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool {
- ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
+ ep, ok := n.networkEndpoints[protocol]
if !ok {
- return false
+ return nil
}
- kind := ref.getKind()
-
- return kind == permanent || kind == permanentTentative
+ return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
}
-type getRefBehaviour int
+type getAddressBehaviour int
const (
// spoofing indicates that the NIC's spoofing flag should be observed when
- // getting a NIC's referenced network endpoint.
- spoofing getRefBehaviour = iota
+ // getting a NIC's address endpoint.
+ spoofing getAddressBehaviour = iota
// promiscuous indicates that the NIC's promiscuous flag should be observed
- // when getting a NIC's referenced network endpoint.
+ // when getting a NIC's address endpoint.
promiscuous
)
-func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
- return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
+func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint {
+ return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
}
// findEndpoint finds the endpoint, if any, with the given address.
-func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
- return n.getRefOrCreateTemp(protocol, address, peb, spoofing)
+func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
+ return n.getAddressOrCreateTemp(protocol, address, peb, spoofing)
}
-// getRefEpOrCreateTemp returns the referenced network endpoint for the given
-// protocol and address.
+// getAddressEpOrCreateTemp returns the address endpoint for the given protocol
+// and address.
//
// If none exists a temporary one may be created if we are in promiscuous mode
// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
// Similarly, spoofing will only be checked if spoofing is true.
-func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
+//
+// If the address is the IPv4 broadcast address for an endpoint's network, that
+// endpoint will be returned.
+func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint {
n.mu.RLock()
-
var spoofingOrPromiscuous bool
switch tempRef {
case spoofing:
@@ -607,267 +304,54 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
case promiscuous:
spoofingOrPromiscuous = n.mu.promiscuous
}
-
- if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
- // An endpoint with this id exists, check if it can be used and return it.
- if !ref.isAssignedRLocked(spoofingOrPromiscuous) {
- n.mu.RUnlock()
- return nil
- }
-
- if ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
- }
- }
-
- // A usable reference was not found, create a temporary one if requested by
- // the caller or if the address is found in the NIC's subnets.
- createTempEP := spoofingOrPromiscuous
- if !createTempEP {
- for _, sn := range n.mu.addressRanges {
- // Skip the subnet address.
- if address == sn.ID() {
- continue
- }
- // For now just skip the broadcast address, until we support it.
- // FIXME(b/137608825): Add support for sending/receiving directed
- // (subnet) broadcast.
- if address == sn.Broadcast() {
- continue
- }
- if sn.Contains(address) {
- createTempEP = true
- break
- }
- }
- }
-
n.mu.RUnlock()
-
- if !createTempEP {
- return nil
- }
-
- // Try again with the lock in exclusive mode. If we still can't get the
- // endpoint, create a new "temporary" endpoint. It will only exist while
- // there's a route through it.
- n.mu.Lock()
- ref := n.getRefOrCreateTempLocked(protocol, address, peb)
- n.mu.Unlock()
- return ref
+ return n.getAddressOrCreateTempInner(protocol, address, spoofingOrPromiscuous, peb)
}
-/// getRefOrCreateTempLocked returns an existing endpoint for address or creates
-/// and returns a temporary endpoint.
-func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
- if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
- // No need to check the type as we are ok with expired endpoints at this
- // point.
- if ref.tryIncRef() {
- return ref
- }
- // tryIncRef failing means the endpoint is scheduled to be removed once the
- // lock is released. Remove it here so we can create a new (temporary) one.
- // The removal logic waiting for the lock handles this case.
- n.removeEndpointLocked(ref)
+// getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean
+// is passed to indicate whether or not we should generate temporary endpoints.
+func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
+ if ep, ok := n.networkEndpoints[protocol]; ok {
+ return ep.AcquireAssignedAddress(address, createTemp, peb)
}
- // Add a new temporary endpoint.
- netProto, ok := n.stack.networkProtocols[protocol]
- if !ok {
- return nil
- }
- ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, peb, temporary, static, false)
- return ref
+ return nil
}
-// addAddressLocked adds a new protocolAddress to n.
-//
-// If n already has the address in a non-permanent state, and the kind given is
-// permanent, that address will be promoted in place and its properties set to
-// the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress.
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
- // TODO(b/141022673): Validate IP addresses before adding them.
-
- // Sanity check.
- id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address}
- if ref, ok := n.mu.endpoints[id]; ok {
- // Endpoint already exists.
- if kind != permanent {
- return nil, tcpip.ErrDuplicateAddress
- }
- switch ref.getKind() {
- case permanentTentative, permanent:
- // The NIC already have a permanent endpoint with that address.
- return nil, tcpip.ErrDuplicateAddress
- case permanentExpired, temporary:
- // Promote the endpoint to become permanent and respect the new peb,
- // configType and deprecated status.
- if ref.tryIncRef() {
- // TODO(b/147748385): Perform Duplicate Address Detection when promoting
- // an IPv6 endpoint to permanent.
- ref.setKind(permanent)
- ref.deprecated = deprecated
- ref.configType = configType
-
- refs := n.mu.primary[ref.protocol]
- for i, r := range refs {
- if r == ref {
- switch peb {
- case CanBePrimaryEndpoint:
- return ref, nil
- case FirstPrimaryEndpoint:
- if i == 0 {
- return ref, nil
- }
- n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
- case NeverPrimaryEndpoint:
- n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
- return ref, nil
- }
- }
- }
-
- n.insertPrimaryEndpointLocked(ref, peb)
-
- return ref, nil
- }
- // tryIncRef failing means the endpoint is scheduled to be removed once
- // the lock is released. Remove it here so we can create a new
- // (permanent) one. The removal logic waiting for the lock handles this
- // case.
- n.removeEndpointLocked(ref)
- }
- }
-
- netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol]
+// addAddress adds a new address to n, so that it starts accepting packets
+// targeted at the given address (and network protocol).
+func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
+ ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
- return nil, tcpip.ErrUnknownProtocol
- }
-
- // Create the new network endpoint.
- ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP, n.stack)
- if err != nil {
- return nil, err
- }
-
- isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
-
- // If the address is an IPv6 address and it is a permanent address,
- // mark it as tentative so it goes through the DAD process if the NIC is
- // enabled. If the NIC is not enabled, DAD will be started when the NIC is
- // enabled.
- if isIPv6Unicast && kind == permanent {
- kind = permanentTentative
- }
-
- ref := &referencedNetworkEndpoint{
- refs: 1,
- ep: ep,
- nic: n,
- protocol: protocolAddress.Protocol,
- kind: kind,
- configType: configType,
- deprecated: deprecated,
- }
-
- // Set up cache if link address resolution exists for this protocol.
- if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
- if _, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok {
- ref.linkCache = n.stack
- }
- }
-
- // If we are adding an IPv6 unicast address, join the solicited-node
- // multicast address.
- if isIPv6Unicast {
- snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address)
- if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil {
- return nil, err
- }
+ return tcpip.ErrUnknownProtocol
}
- n.mu.endpoints[id] = ref
-
- n.insertPrimaryEndpointLocked(ref, peb)
-
- // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled.
- if isIPv6Unicast && kind == permanentTentative && n.mu.enabled {
- if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
- return nil, err
- }
+ addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ if err == nil {
+ // We have no need for the address endpoint.
+ addressEndpoint.DecRef()
}
-
- return ref, nil
-}
-
-// AddAddress adds a new address to n, so that it starts accepting packets
-// targeted at the given address (and network protocol).
-func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
- // Add the endpoint.
- n.mu.Lock()
- _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */)
- n.mu.Unlock()
-
return err
}
-// AllAddresses returns all addresses (primary and non-primary) associated with
+// allPermanentAddresses returns all permanent addresses associated with
// this NIC.
-func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints))
- for nid, ref := range n.mu.endpoints {
- // Don't include tentative, expired or temporary endpoints to
- // avoid confusion and prevent the caller from using those.
- switch ref.getKind() {
- case permanentExpired, temporary:
- continue
+func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
+ var addrs []tcpip.ProtocolAddress
+ for p, ep := range n.networkEndpoints {
+ for _, a := range ep.PermanentAddresses() {
+ addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
-
- addrs = append(addrs, tcpip.ProtocolAddress{
- Protocol: ref.protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: nid.LocalAddress,
- PrefixLen: ref.ep.PrefixLen(),
- },
- })
}
return addrs
}
-// PrimaryAddresses returns the primary addresses associated with this NIC.
-func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
+// primaryAddresses returns the primary addresses associated with this NIC.
+func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
- for proto, list := range n.mu.primary {
- for _, ref := range list {
- // Don't include tentative, expired or tempory endpoints
- // to avoid confusion and prevent the caller from using
- // those.
- switch ref.getKind() {
- case permanentTentative, permanentExpired, temporary:
- continue
- }
-
- addrs = append(addrs, tcpip.ProtocolAddress{
- Protocol: proto,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: ref.ep.ID().LocalAddress,
- PrefixLen: ref.ep.PrefixLen(),
- },
- })
+ for p, ep := range n.networkEndpoints {
+ for _, a := range ep.PrimaryAddresses() {
+ addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
}
return addrs
@@ -879,289 +363,135 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
// address exists. If no non-deprecated address exists, the first deprecated
// address will be returned.
func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- list, ok := n.mu.primary[proto]
+ ep, ok := n.networkEndpoints[proto]
if !ok {
return tcpip.AddressWithPrefix{}
}
- var deprecatedEndpoint *referencedNetworkEndpoint
- for _, ref := range list {
- // Don't include tentative, expired or tempory endpoints to avoid confusion
- // and prevent the caller from using those.
- switch ref.getKind() {
- case permanentTentative, permanentExpired, temporary:
- continue
- }
-
- if !ref.deprecated {
- return tcpip.AddressWithPrefix{
- Address: ref.ep.ID().LocalAddress,
- PrefixLen: ref.ep.PrefixLen(),
- }
- }
-
- if deprecatedEndpoint == nil {
- deprecatedEndpoint = ref
- }
- }
-
- if deprecatedEndpoint != nil {
- return tcpip.AddressWithPrefix{
- Address: deprecatedEndpoint.ep.ID().LocalAddress,
- PrefixLen: deprecatedEndpoint.ep.PrefixLen(),
- }
- }
-
- return tcpip.AddressWithPrefix{}
+ return ep.MainAddress()
}
-// AddAddressRange adds a range of addresses to n, so that it starts accepting
-// packets targeted at the given addresses and network protocol. The range is
-// given by a subnet address, and all addresses contained in the subnet are
-// used except for the subnet address itself and the subnet's broadcast
-// address.
-func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
- n.mu.Lock()
- n.mu.addressRanges = append(n.mu.addressRanges, subnet)
- n.mu.Unlock()
-}
-
-// RemoveAddressRange removes the given address range from n.
-func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) {
- n.mu.Lock()
-
- // Use the same underlying array.
- tmp := n.mu.addressRanges[:0]
- for _, sub := range n.mu.addressRanges {
- if sub != subnet {
- tmp = append(tmp, sub)
+// removeAddress removes an address from n.
+func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error {
+ for _, ep := range n.networkEndpoints {
+ if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
+ continue
+ } else {
+ return err
}
}
- n.mu.addressRanges = tmp
- n.mu.Unlock()
+ return tcpip.ErrBadLocalAddress
}
-// AddressRanges returns the Subnets associated with this NIC.
-func (n *NIC) AddressRanges() []tcpip.Subnet {
- n.mu.RLock()
- defer n.mu.RUnlock()
- sns := make([]tcpip.Subnet, 0, len(n.mu.addressRanges)+len(n.mu.endpoints))
- for nid := range n.mu.endpoints {
- sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
- if err != nil {
- // This should never happen as the mask has been carefully crafted to
- // match the address.
- panic("Invalid endpoint subnet: " + err.Error())
- }
- sns = append(sns, sn)
+func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
+ if n.neigh == nil {
+ return nil, tcpip.ErrNotSupported
}
- return append(sns, n.mu.addressRanges...)
-}
-// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required
-// by peb.
-//
-// n MUST be locked.
-func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) {
- switch peb {
- case CanBePrimaryEndpoint:
- n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r)
- case FirstPrimaryEndpoint:
- n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...)
- }
+ return n.neigh.entries(), nil
}
-func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
- id := *r.ep.ID()
-
- // Nothing to do if the reference has already been replaced with a different
- // one. This happens in the case where 1) this endpoint's ref count hit zero
- // and was waiting (on the lock) to be removed and 2) the same address was
- // re-added in the meantime by removing this endpoint from the list and
- // adding a new one.
- if n.mu.endpoints[id] != r {
+func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) {
+ if n.neigh == nil {
return
}
- if r.getKind() == permanent {
- panic("Reference count dropped to zero before being removed")
- }
+ n.neigh.removeWaker(addr, w)
+}
- delete(n.mu.endpoints, id)
- refs := n.mu.primary[r.protocol]
- for i, ref := range refs {
- if ref == r {
- n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
- refs[len(refs)-1] = nil
- break
- }
+func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
}
- r.ep.Close()
-}
-
-func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
- n.mu.Lock()
- n.removeEndpointLocked(r)
- n.mu.Unlock()
+ n.neigh.addStaticEntry(addr, linkAddress)
+ return nil
}
-func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
- r, ok := n.mu.endpoints[NetworkEndpointID{addr}]
- if !ok {
- return tcpip.ErrBadLocalAddress
- }
-
- kind := r.getKind()
- if kind != permanent && kind != permanentTentative {
- return tcpip.ErrBadLocalAddress
+func (n *NIC) removeNeighbor(addr tcpip.Address) *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
}
- switch r.protocol {
- case header.IPv6ProtocolNumber:
- return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */)
- default:
- r.expireLocked()
- return nil
+ if !n.neigh.removeEntry(addr) {
+ return tcpip.ErrBadAddress
}
+ return nil
}
-func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
- addr := r.addrWithPrefix()
-
- isIPv6Unicast := header.IsV6UnicastAddress(addr.Address)
-
- if isIPv6Unicast {
- n.mu.ndp.stopDuplicateAddressDetection(addr.Address)
-
- // If we are removing an address generated via SLAAC, cleanup
- // its SLAAC resources and notify the integrator.
- switch r.configType {
- case slaac:
- n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
- case slaacTemp:
- n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
- }
- }
-
- r.expireLocked()
-
- // At this point the endpoint is deleted.
-
- // If we are removing an IPv6 unicast address, leave the solicited-node
- // multicast address.
- //
- // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node
- // multicast group may be left by user action.
- if isIPv6Unicast {
- snmc := header.SolicitedNodeAddr(addr.Address)
- if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
- return err
- }
+func (n *NIC) clearNeighbors() *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
}
+ n.neigh.clear()
return nil
}
-// RemoveAddress removes an address from n.
-func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
- return n.removePermanentAddressLocked(addr)
-}
-
// joinGroup adds a new endpoint for the given multicast address, if none
// exists yet. Otherwise it just increments its count.
func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- return n.joinGroupLocked(protocol, addr)
-}
-
-// joinGroupLocked adds a new endpoint for the given multicast address, if none
-// exists yet. Otherwise it just increments its count. n MUST be locked before
-// joinGroupLocked is called.
-func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
// TODO(b/143102137): When implementing MLD, make sure MLD packets are
// not sent unless a valid link-local address is available for use on n
// as an MLD packet's source address must be a link-local address as
// outlined in RFC 3810 section 5.
- id := NetworkEndpointID{addr}
- joins := n.mu.mcastJoins[id]
- if joins == 0 {
- netProto, ok := n.stack.networkProtocols[protocol]
- if !ok {
- return tcpip.ErrUnknownProtocol
- }
- if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
- return err
- }
+ ep, ok := n.networkEndpoints[protocol]
+ if !ok {
+ return tcpip.ErrNotSupported
}
- n.mu.mcastJoins[id] = joins + 1
- return nil
+
+ gep, ok := ep.(GroupAddressableEndpoint)
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+
+ _, err := gep.JoinGroup(addr)
+ return err
}
// leaveGroup decrements the count for the given multicast address, and when it
// reaches zero removes the endpoint for this address.
-func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- return n.leaveGroupLocked(addr, false /* force */)
-}
+func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ ep, ok := n.networkEndpoints[protocol]
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
-// leaveGroupLocked decrements the count for the given multicast address, and
-// when it reaches zero removes the endpoint for this address. n MUST be locked
-// before leaveGroupLocked is called.
-//
-// If force is true, then the count for the multicast addres is ignored and the
-// endpoint will be removed immediately.
-func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error {
- id := NetworkEndpointID{addr}
- joins, ok := n.mu.mcastJoins[id]
+ gep, ok := ep.(GroupAddressableEndpoint)
if !ok {
- // There are no joins with this address on this NIC.
- return tcpip.ErrBadLocalAddress
+ return tcpip.ErrNotSupported
}
- joins--
- if force || joins == 0 {
- // There are no outstanding joins or we are forced to leave, clean up.
- delete(n.mu.mcastJoins, id)
- return n.removePermanentAddressLocked(addr)
+ if _, err := gep.LeaveGroup(addr); err != nil {
+ return err
}
- n.mu.mcastJoins[id] = joins
return nil
}
// isInGroup returns true if n has joined the multicast group addr.
func (n *NIC) isInGroup(addr tcpip.Address) bool {
- n.mu.RLock()
- joins := n.mu.mcastJoins[NetworkEndpointID{addr}]
- n.mu.RUnlock()
+ for _, ep := range n.networkEndpoints {
+ gep, ok := ep.(GroupAddressableEndpoint)
+ if !ok {
+ continue
+ }
- return joins != 0
+ if gep.IsInGroup(addr) {
+ return true
+ }
+ }
+
+ return false
}
-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) {
- r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
+func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
+ r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
+ defer r.Release()
r.RemoteLinkAddress = remotelinkAddr
-
- ref.ep.HandlePacket(&r, pkt)
- ref.decRef()
+ n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
@@ -1172,7 +502,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// the ownership of the items is not retained by the caller.
func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
n.mu.RLock()
- enabled := n.mu.enabled
+ enabled := n.Enabled()
// If the NIC is not yet enabled, don't receive any packets.
if !enabled {
n.mu.RUnlock()
@@ -1198,17 +528,15 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
local = n.linkEP.LinkAddress()
}
- // Are any packet sockets listening for this network protocol?
+ // Are any packet type sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
- // Check whether there are packet sockets listening for every protocol.
- // If we received a packet with protocol EthernetProtocolAll, then the
- // previous for loop will have handled it.
- if protocol != header.EthernetProtocolAll {
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
- }
+ // Add any other packet type sockets that may be listening for all protocols.
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
- ep.HandlePacket(n.id, local, protocol, pkt.Clone())
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketHost
+ ep.HandlePacket(n.id, local, protocol, p)
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
@@ -1223,37 +551,42 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
return
}
if hasTransportHdr {
+ pkt.TransportProtocolNumber = transProtoNum
// Parse the transport header if present.
if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
state.proto.Parse(pkt)
}
}
- src, dst := netProto.ParseAddresses(pkt.NetworkHeader)
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
- if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
- // The source address is one of our own, so we never should have gotten a
- // packet like this unless handleLocal is false. Loopback also calls this
- // function even though the packets didn't come from the physical interface
- // so don't drop those.
- n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
- return
+ if n.stack.handleLocal && !n.IsLoopback() {
+ if r := n.getAddress(protocol, src); r != nil {
+ r.DecRef()
+
+ // The source address is one of our own, so we never should have gotten a
+ // packet like this unless handleLocal is false. Loopback also calls this
+ // function even though the packets didn't come from the physical interface
+ // so don't drop those.
+ n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
}
- // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
// Loopback traffic skips the prerouting chain.
- if protocol == header.IPv4ProtocolNumber && !n.isLoopback() {
+ if !n.IsLoopback() {
// iptables filtering.
ipt := n.stack.IPTables()
address := n.primaryAddress(protocol)
if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
// iptables is telling us to drop the packet.
+ n.stack.stats.IP.IPTablesPreroutingDropped.Increment()
return
}
}
- if ref := n.getRef(protocol, dst); ref != nil {
- handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt)
+ if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil {
+ n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt)
return
}
@@ -1261,7 +594,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// packet and forward it to the NIC.
//
// TODO: Should we be forwarding the packet even if promiscuous?
- if n.stack.Forwarding() {
+ if n.stack.Forwarding(protocol) {
r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
if err != nil {
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
@@ -1269,25 +602,26 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
// Found a NIC.
- n := r.ref.nic
- n.mu.RLock()
- ref, ok := n.mu.endpoints[NetworkEndpointID{dst}]
- ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef()
- n.mu.RUnlock()
- if ok {
- r.LocalLinkAddress = n.linkEP.LinkAddress()
- r.RemoteLinkAddress = remote
- r.RemoteAddress = src
- // TODO(b/123449044): Update the source NIC as well.
- ref.ep.HandlePacket(&r, pkt)
- ref.decRef()
- r.Release()
- return
+ n := r.nic
+ if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
+ if n.isValidForOutgoing(addressEndpoint) {
+ r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.RemoteLinkAddress = remote
+ r.RemoteAddress = src
+ // TODO(b/123449044): Update the source NIC as well.
+ n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
+ addressEndpoint.DecRef()
+ r.Release()
+ return
+ }
+
+ addressEndpoint.DecRef()
}
// n doesn't have a destination endpoint.
// Send the packet out of n.
// TODO(b/128629022): move this logic to route.WritePacket.
+ // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
@@ -1311,26 +645,39 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
+// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
+func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ n.mu.RLock()
+ // We do not deliver to protocol specific packet endpoints as on Linux
+ // only ETH_P_ALL endpoints get outbound packets.
+ // Add any other packet sockets that maybe listening for all protocols.
+ packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketOutgoing
+ // Add the link layer header as outgoing packets are intercepted
+ // before the link layer header is created.
+ n.linkEP.AddHeader(local, remote, protocol, p)
+ ep.HandlePacket(n.id, local, protocol, p)
+ }
+}
+
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
- // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
- // by having lower layers explicity write each header instead of just
- // pkt.Header.
-
- // pkt may have set its NetworkHeader and TransportHeader. If we're
- // forwarding, we'll have to copy them into pkt.Header.
- pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader))
- if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) {
- panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader)))
- }
- if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) {
- panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader)))
- }
- // WritePacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Header.UsedLength() + pkt.Data.Size()
+ // pkt may have set its header and may not have enough headroom for link-layer
+ // header for the other link to prepend. Here we create a new packet to
+ // forward.
+ fwdPkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()),
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ })
+
+ // WritePacket takes ownership of fwdPkt, calculate numBytes first.
+ numBytes := fwdPkt.Size()
- if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
+ if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return
}
@@ -1341,11 +688,11 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
- return
+ return TransportPacketProtocolUnreachable
}
transProto := state.proto
@@ -1355,52 +702,58 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// validly formed.
n.stack.demux.deliverRawPacket(r, protocol, pkt)
- // TransportHeader is nil only when pkt is an ICMP packet or was reassembled
+ // TransportHeader is empty only when pkt is an ICMP packet or was reassembled
// from fragments.
- if pkt.TransportHeader == nil {
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their
- // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ if pkt.TransportHeader().View().IsEmpty() {
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
// full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
- transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
- if !ok {
+ // ICMP packets may be longer, but until icmp.Parse is implemented, here
+ // we parse it using the minimum size.
+ if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok {
n.stack.stats.MalformedRcvdPackets.Increment()
- return
+ // We consider a malformed transport packet handled because there is
+ // nothing the caller can do.
+ return TransportPacketHandled
}
- pkt.TransportHeader = transHeader
- } else {
- // This is either a bad packet or was re-assembled from fragments.
- transProto.Parse(pkt)
+ } else if !transProto.Parse(pkt) {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return TransportPacketHandled
}
}
- if len(pkt.TransportHeader) < transProto.MinimumPacketSize() {
- n.stack.stats.MalformedRcvdPackets.Increment()
- return
- }
-
- srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader)
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View())
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
- return
+ return TransportPacketHandled
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
if n.stack.demux.deliverPacket(r, protocol, pkt, id) {
- return
+ return TransportPacketHandled
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
if state.defaultHandler(r, id, pkt) {
- return
+ return TransportPacketHandled
}
}
- // We could not find an appropriate destination for this packet, so
- // deliver it to the global handler.
- if !transProto.HandleUnknownDestinationPacket(r, id, pkt) {
+ // We could not find an appropriate destination for this packet so
+ // give the protocol specific error handler a chance to handle it.
+ // If it doesn't handle it then we should do so.
+ switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res {
+ case UnknownDestinationPacketMalformed:
n.stack.stats.MalformedRcvdPackets.Increment()
+ return TransportPacketHandled
+ case UnknownDestinationPacketUnhandled:
+ return TransportPacketDestinationPortUnreachable
+ case UnknownDestinationPacketHandled:
+ return TransportPacketHandled
+ default:
+ panic(fmt.Sprintf("unrecognized result from HandleUnknownDestinationPacket = %d", res))
}
}
@@ -1433,137 +786,42 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
}
}
-// ID returns the identifier of n.
+// ID implements NetworkInterface.
func (n *NIC) ID() tcpip.NICID {
return n.id
}
-// Name returns the name of n.
+// Name implements NetworkInterface.
func (n *NIC) Name() string {
return n.name
}
-// Stack returns the instance of the Stack that owns this NIC.
-func (n *NIC) Stack() *Stack {
- return n.stack
-}
-
-// LinkEndpoint returns the link endpoint of n.
+// LinkEndpoint implements NetworkInterface.
func (n *NIC) LinkEndpoint() LinkEndpoint {
return n.linkEP
}
-// isAddrTentative returns true if addr is tentative on n.
-//
-// Note that if addr is not associated with n, then this function will return
-// false. It will only return true if the address is associated with the NIC
-// AND it is tentative.
-func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
- if !ok {
- return false
+// nudConfigs gets the NUD configurations for n.
+func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) {
+ if n.neigh == nil {
+ return NUDConfigurations{}, tcpip.ErrNotSupported
}
-
- return ref.getKind() == permanentTentative
+ return n.neigh.config(), nil
}
-// dupTentativeAddrDetected attempts to inform n that a tentative addr is a
-// duplicate on a link.
+// setNUDConfigs sets the NUD configurations for n.
//
-// dupTentativeAddrDetected will remove the tentative address if it exists. If
-// the address was generated via SLAAC, an attempt will be made to generate a
-// new address.
-func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
- if !ok {
- return tcpip.ErrBadAddress
- }
-
- if ref.getKind() != permanentTentative {
- return tcpip.ErrInvalidEndpointState
- }
-
- // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a
- // new address will be generated for it.
- if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil {
- return err
- }
-
- prefix := ref.addrWithPrefix().Subnet()
-
- switch ref.configType {
- case slaac:
- n.mu.ndp.regenerateSLAACAddr(prefix)
- case slaacTemp:
- // Do not reset the generation attempts counter for the prefix as the
- // temporary address is being regenerated in response to a DAD conflict.
- n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
+// Note, if c contains invalid NUD configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
}
-
+ c.resetInvalidFields()
+ n.neigh.setConfig(c)
return nil
}
-// setNDPConfigs sets the NDP configurations for n.
-//
-// Note, if c contains invalid NDP configuration values, it will be fixed to
-// use default values for the erroneous values.
-func (n *NIC) setNDPConfigs(c NDPConfigurations) {
- c.validate()
-
- n.mu.Lock()
- n.mu.ndp.configs = c
- n.mu.Unlock()
-}
-
-// handleNDPRA handles an NDP Router Advertisement message that arrived on n.
-func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- n.mu.ndp.handleRA(ip, ra)
-}
-
-type networkEndpointKind int32
-
-const (
- // A permanentTentative endpoint is a permanent address that is not yet
- // considered to be fully bound to an interface in the traditional
- // sense. That is, the address is associated with a NIC, but packets
- // destined to the address MUST NOT be accepted and MUST be silently
- // dropped, and the address MUST NOT be used as a source address for
- // outgoing packets. For IPv6, addresses will be of this kind until
- // NDP's Duplicate Address Detection has resolved, or be deleted if
- // the process results in detecting a duplicate address.
- permanentTentative networkEndpointKind = iota
-
- // A permanent endpoint is created by adding a permanent address (vs. a
- // temporary one) to the NIC. Its reference count is biased by 1 to avoid
- // removal when no route holds a reference to it. It is removed by explicitly
- // removing the permanent address from the NIC.
- permanent
-
- // An expired permanent endpoint is a permanent endpoint that had its address
- // removed from the NIC, and it is waiting to be removed once no more routes
- // hold a reference to it. This is achieved by decreasing its reference count
- // by 1. If its address is re-added before the endpoint is removed, its type
- // changes back to permanent and its reference count increases by 1 again.
- permanentExpired
-
- // A temporary endpoint is created for spoofing outgoing packets, or when in
- // promiscuous mode and accepting incoming packets that don't match any
- // permanent endpoint. Its reference count is not biased by 1 and the
- // endpoint is removed immediately when no more route holds a reference to
- // it. A temporary endpoint can be promoted to permanent if its address
- // is added permanently.
- temporary
-)
-
func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
@@ -1594,147 +852,12 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
}
}
-type networkEndpointConfigType int32
-
-const (
- // A statically configured endpoint is an address that was added by
- // some user-specified action (adding an explicit address, joining a
- // multicast group).
- static networkEndpointConfigType = iota
-
- // A SLAAC configured endpoint is an IPv6 endpoint that was added by
- // SLAAC as per RFC 4862 section 5.5.3.
- slaac
-
- // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by
- // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are
- // not expected to be valid (or preferred) forever; hence the term temporary.
- slaacTemp
-)
-
-type referencedNetworkEndpoint struct {
- ep NetworkEndpoint
- nic *NIC
- protocol tcpip.NetworkProtocolNumber
-
- // linkCache is set if link address resolution is enabled for this
- // protocol. Set to nil otherwise.
- linkCache LinkAddressCache
-
- // refs is counting references held for this endpoint. When refs hits zero it
- // triggers the automatic removal of the endpoint from the NIC.
- refs int32
-
- // networkEndpointKind must only be accessed using {get,set}Kind().
- kind networkEndpointKind
-
- // configType is the method that was used to configure this endpoint.
- // This must never change except during endpoint creation and promotion to
- // permanent.
- configType networkEndpointConfigType
-
- // deprecated indicates whether or not the endpoint should be considered
- // deprecated. That is, when deprecated is true, other endpoints that are not
- // deprecated should be preferred.
- deprecated bool
-}
-
-func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix {
- return tcpip.AddressWithPrefix{
- Address: r.ep.ID().LocalAddress,
- PrefixLen: r.ep.PrefixLen(),
- }
-}
-
-func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
- return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind)))
-}
-
-func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
- atomic.StoreInt32((*int32)(&r.kind), int32(kind))
-}
-
// isValidForOutgoing returns true if the endpoint can be used to send out a
// packet. It requires the endpoint to not be marked expired (i.e., its address)
// has been removed) unless the NIC is in spoofing mode, or temporary.
-func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
- r.nic.mu.RLock()
- defer r.nic.mu.RUnlock()
-
- return r.isValidForOutgoingRLocked()
-}
-
-// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires
-// r.nic.mu to be read locked.
-func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
- if !r.nic.mu.enabled {
- return false
- }
-
- return r.isAssignedRLocked(r.nic.mu.spoofing)
-}
-
-// isAssignedRLocked returns true if r is considered to be assigned to the NIC.
-//
-// r.nic.mu must be read locked.
-func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool {
- switch r.getKind() {
- case permanentTentative:
- return false
- case permanentExpired:
- return spoofingOrPromiscuous
- default:
- return true
- }
-}
-
-// expireLocked decrements the reference count and marks the permanent endpoint
-// as expired.
-func (r *referencedNetworkEndpoint) expireLocked() {
- r.setKind(permanentExpired)
- r.decRefLocked()
-}
-
-// decRef decrements the ref count and cleans up the endpoint once it reaches
-// zero.
-func (r *referencedNetworkEndpoint) decRef() {
- if atomic.AddInt32(&r.refs, -1) == 0 {
- r.nic.removeEndpoint(r)
- }
-}
-
-// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
-// locked.
-func (r *referencedNetworkEndpoint) decRefLocked() {
- if atomic.AddInt32(&r.refs, -1) == 0 {
- r.nic.removeEndpointLocked(r)
- }
-}
-
-// incRef increments the ref count. It must only be called when the caller is
-// known to be holding a reference to the endpoint, otherwise tryIncRef should
-// be used.
-func (r *referencedNetworkEndpoint) incRef() {
- atomic.AddInt32(&r.refs, 1)
-}
-
-// tryIncRef attempts to increment the ref count from n to n+1, but only if n is
-// not zero. That is, it will increment the count if the endpoint is still
-// alive, and do nothing if it has already been clean up.
-func (r *referencedNetworkEndpoint) tryIncRef() bool {
- for {
- v := atomic.LoadInt32(&r.refs)
- if v == 0 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(&r.refs, v, v+1) {
- return true
- }
- }
-}
-
-// stack returns the Stack instance that owns the underlying endpoint.
-func (r *referencedNetworkEndpoint) stack() *Stack {
- return r.nic.stack
+func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
+ n.mu.RLock()
+ spoofing := n.mu.spoofing
+ n.mu.RUnlock()
+ return n.Enabled() && ep.IsAssigned(spoofing)
}
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 31f865260..fdd49b77f 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -15,88 +15,40 @@
package stack
import (
- "math"
"testing"
- "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-var _ LinkEndpoint = (*testLinkEndpoint)(nil)
+var _ AddressableEndpoint = (*testIPv6Endpoint)(nil)
+var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
+var _ NDPEndpoint = (*testIPv6Endpoint)(nil)
-// A LinkEndpoint that throws away outgoing packets.
+// An IPv6 NetworkEndpoint that throws away outgoing packets.
//
-// We use this instead of the channel endpoint as the channel package depends on
+// We use this instead of ipv6.endpoint because the ipv6 package depends on
// the stack package which this test lives in, causing a cyclic dependency.
-type testLinkEndpoint struct {
- dispatcher NetworkDispatcher
-}
-
-// Attach implements LinkEndpoint.Attach.
-func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
- e.dispatcher = dispatcher
-}
-
-// IsAttached implements LinkEndpoint.IsAttached.
-func (e *testLinkEndpoint) IsAttached() bool {
- return e.dispatcher != nil
-}
-
-// MTU implements LinkEndpoint.MTU.
-func (*testLinkEndpoint) MTU() uint32 {
- return math.MaxUint16
-}
-
-// Capabilities implements LinkEndpoint.Capabilities.
-func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities {
- return CapabilityResolutionRequired
-}
+type testIPv6Endpoint struct {
+ AddressableEndpointState
-// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
-func (*testLinkEndpoint) MaxHeaderLength() uint16 {
- return 0
-}
+ nicID tcpip.NICID
+ linkEP LinkEndpoint
+ protocol *testIPv6Protocol
-// LinkAddress returns the link address of this endpoint.
-func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress {
- return ""
+ invalidatedRtr tcpip.Address
}
-// Wait implements LinkEndpoint.Wait.
-func (*testLinkEndpoint) Wait() {}
-
-// WritePacket implements LinkEndpoint.WritePacket.
-func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error {
+func (*testIPv6Endpoint) Enable() *tcpip.Error {
return nil
}
-// WritePackets implements LinkEndpoint.WritePackets.
-func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- // Our tests don't use this so we don't support it.
- return 0, tcpip.ErrNotSupported
-}
-
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- // Our tests don't use this so we don't support it.
- return tcpip.ErrNotSupported
+func (*testIPv6Endpoint) Enabled() bool {
+ return true
}
-var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
-
-// An IPv6 NetworkEndpoint that throws away outgoing packets.
-//
-// We use this instead of ipv6.endpoint because the ipv6 package depends on
-// the stack package which this test lives in, causing a cyclic dependency.
-type testIPv6Endpoint struct {
- nicID tcpip.NICID
- id NetworkEndpointID
- prefixLen int
- linkEP LinkEndpoint
- protocol *testIPv6Protocol
-}
+func (*testIPv6Endpoint) Disable() {}
// DefaultTTL implements NetworkEndpoint.DefaultTTL.
func (*testIPv6Endpoint) DefaultTTL() uint8 {
@@ -108,11 +60,6 @@ func (e *testIPv6Endpoint) MTU() uint32 {
return e.linkEP.MTU() - header.IPv6MinimumSize
}
-// Capabilities implements NetworkEndpoint.Capabilities.
-func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
@@ -136,33 +83,24 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip
return tcpip.ErrNotSupported
}
-// ID implements NetworkEndpoint.ID.
-func (e *testIPv6Endpoint) ID() *NetworkEndpointID {
- return &e.id
-}
-
-// PrefixLen implements NetworkEndpoint.PrefixLen.
-func (e *testIPv6Endpoint) PrefixLen() int {
- return e.prefixLen
-}
-
-// NICID implements NetworkEndpoint.NICID.
-func (e *testIPv6Endpoint) NICID() tcpip.NICID {
- return e.nicID
-}
-
// HandlePacket implements NetworkEndpoint.HandlePacket.
func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
}
// Close implements NetworkEndpoint.Close.
-func (*testIPv6Endpoint) Close() {}
+func (e *testIPv6Endpoint) Close() {
+ e.AddressableEndpointState.Cleanup()
+}
// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber.
func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return header.IPv6ProtocolNumber
}
+func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
+ e.invalidatedRtr = rtr
+}
+
var _ NetworkProtocol = (*testIPv6Protocol)(nil)
// An IPv6 NetworkProtocol that supports the bare minimum to make a stack
@@ -194,23 +132,23 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address)
}
// NewEndpoint implements NetworkProtocol.NewEndpoint.
-func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
- return &testIPv6Endpoint{
- nicID: nicID,
- id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
- linkEP: linkEP,
- protocol: p,
- }, nil
+func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint {
+ e := &testIPv6Endpoint{
+ nicID: nic.ID(),
+ linkEP: nic.LinkEndpoint(),
+ protocol: p,
+ }
+ e.AddressableEndpointState.Init(e)
+ return e
}
// SetOption implements NetworkProtocol.SetOption.
-func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error {
+func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
return nil
}
// Option implements NetworkProtocol.Option.
-func (*testIPv6Protocol) Option(interface{}) *tcpip.Error {
+func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
return nil
}
@@ -233,7 +171,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements LinkAddressResolver.
-func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
return nil
}
@@ -245,38 +183,6 @@ func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAdd
return "", false
}
-// Test the race condition where a NIC is removed and an RS timer fires at the
-// same time.
-func TestRemoveNICWhileHandlingRSTimer(t *testing.T) {
- const (
- nicID = 1
-
- maxRtrSolicitations = 5
- )
-
- e := testLinkEndpoint{}
- s := New(Options{
- NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}},
- NDPConfigs: NDPConfigurations{
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: minimumRtrSolicitationInterval,
- },
- })
-
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err)
- }
-
- s.mu.Lock()
- // Wait for the router solicitation timer to fire and block trying to obtain
- // the stack lock when doing link address resolution.
- time.Sleep(minimumRtrSolicitationInterval * 2)
- if err := s.removeNICLocked(nicID); err != nil {
- t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err)
- }
- s.mu.Unlock()
-}
-
func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
// When the NIC is disabled, the only field that matters is the stats field.
// This test is limited to stats counter checks.
@@ -301,7 +207,9 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
t.FailNow()
}
- nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
+ nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{
+ Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(),
+ }))
if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
new file mode 100644
index 000000000..e1ec15487
--- /dev/null
+++ b/pkg/tcpip/stack/nud.go
@@ -0,0 +1,466 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "math"
+ "sync"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // defaultBaseReachableTime is the default base duration for computing the
+ // random reachable time.
+ //
+ // Reachable time is the duration for which a neighbor is considered
+ // reachable after a positive reachability confirmation is received. It is a
+ // function of a uniformly distributed random value between the minimum and
+ // maximum random factors, multiplied by the base reachable time. Using a
+ // random component eliminates the possibility that Neighbor Unreachability
+ // Detection messages will synchronize with each other.
+ //
+ // Default taken from REACHABLE_TIME of RFC 4861 section 10.
+ defaultBaseReachableTime = 30 * time.Second
+
+ // minimumBaseReachableTime is the minimum base duration for computing the
+ // random reachable time.
+ //
+ // Minimum = 1ms
+ minimumBaseReachableTime = time.Millisecond
+
+ // defaultMinRandomFactor is the default minimum value of the random factor
+ // used for computing reachable time.
+ //
+ // Default taken from MIN_RANDOM_FACTOR of RFC 4861 section 10.
+ defaultMinRandomFactor = 0.5
+
+ // defaultMaxRandomFactor is the default maximum value of the random factor
+ // used for computing reachable time.
+ //
+ // The default value depends on the value of MinRandomFactor.
+ // If MinRandomFactor is less than MAX_RANDOM_FACTOR of RFC 4861 section 10,
+ // the value from the RFC will be used; otherwise, the default is
+ // MinRandomFactor multiplied by three.
+ defaultMaxRandomFactor = 1.5
+
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending reachability probes.
+ //
+ // Default taken from RETRANS_TIMER of RFC 4861 section 10.
+ defaultRetransmitTimer = time.Second
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending reachability probes.
+ //
+ // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here
+ // to make sure the messages are not sent all at once. We also come to this
+ // value because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1. Note, the
+ // unit of the RetransmitTimer field in the Router Advertisement is
+ // milliseconds.
+ minimumRetransmitTimer = time.Millisecond
+
+ // defaultDelayFirstProbeTime is the default duration to wait for a
+ // non-Neighbor-Discovery related protocol to reconfirm reachability after
+ // entering the DELAY state. After this time, a reachability probe will be
+ // sent and the entry will transition to the PROBE state.
+ //
+ // Default taken from DELAY_FIRST_PROBE_TIME of RFC 4861 section 10.
+ defaultDelayFirstProbeTime = 5 * time.Second
+
+ // defaultMaxMulticastProbes is the default number of reachabililty probes
+ // to send before concluding negative reachability and deleting the neighbor
+ // entry from the INCOMPLETE state.
+ //
+ // Default taken from MAX_MULTICAST_SOLICIT of RFC 4861 section 10.
+ defaultMaxMulticastProbes = 3
+
+ // defaultMaxUnicastProbes is the default number of reachability probes to
+ // send before concluding retransmission from within the PROBE state should
+ // cease and the entry SHOULD be deleted.
+ //
+ // Default taken from MAX_UNICASE_SOLICIT of RFC 4861 section 10.
+ defaultMaxUnicastProbes = 3
+
+ // defaultMaxAnycastDelayTime is the default time in which the stack SHOULD
+ // delay sending a response for a random time between 0 and this time, if the
+ // target address is an anycast address.
+ //
+ // Default taken from MAX_ANYCAST_DELAY_TIME of RFC 4861 section 10.
+ defaultMaxAnycastDelayTime = time.Second
+
+ // defaultMaxReachbilityConfirmations is the default amount of unsolicited
+ // reachability confirmation messages a node MAY send to all-node multicast
+ // address when it determines its link-layer address has changed.
+ //
+ // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10.
+ defaultMaxReachbilityConfirmations = 3
+
+ // defaultUnreachableTime is the default duration for how long an entry will
+ // remain in the FAILED state before being removed from the neighbor cache.
+ //
+ // Note, there is no equivalent protocol constant defined in RFC 4861. It
+ // leaves the specifics of any garbage collection mechanism up to the
+ // implementation.
+ defaultUnreachableTime = 5 * time.Second
+)
+
+// NUDDispatcher is the interface integrators of netstack must implement to
+// receive and handle NUD related events.
+type NUDDispatcher interface {
+ // OnNeighborAdded will be called when a new entry is added to a NIC's (with
+ // ID nicID) neighbor table.
+ //
+ // This function is permitted to block indefinitely without interfering with
+ // the stack's operation.
+ //
+ // May be called concurrently.
+ OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+
+ // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID)
+ // neighbor table changes state and/or link address.
+ //
+ // This function is permitted to block indefinitely without interfering with
+ // the stack's operation.
+ //
+ // May be called concurrently.
+ OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+
+ // OnNeighborRemoved will be called when an entry is removed from a NIC's
+ // (with ID nicID) neighbor table.
+ //
+ // This function is permitted to block indefinitely without interfering with
+ // the stack's operation.
+ //
+ // May be called concurrently.
+ OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+}
+
+// ReachabilityConfirmationFlags describes the flags used within a reachability
+// confirmation (e.g. ARP reply or Neighbor Advertisement for ARP or NDP,
+// respectively).
+type ReachabilityConfirmationFlags struct {
+ // Solicited indicates that the advertisement was sent in response to a
+ // reachability probe.
+ Solicited bool
+
+ // Override indicates that the reachability confirmation should override an
+ // existing neighbor cache entry and update the cached link-layer address.
+ // When Override is not set the confirmation will not update a cached
+ // link-layer address, but will update an existing neighbor cache entry for
+ // which no link-layer address is known.
+ Override bool
+
+ // IsRouter indicates that the sender is a router.
+ IsRouter bool
+}
+
+// NUDHandler communicates external events to the Neighbor Unreachability
+// Detection state machine, which is implemented per-interface. This is used by
+// network endpoints to inform the Neighbor Cache of probes and confirmations.
+type NUDHandler interface {
+ // HandleProbe processes an incoming neighbor probe (e.g. ARP request or
+ // Neighbor Solicitation for ARP or NDP, respectively). Validation of the
+ // probe needs to be performed before calling this function since the
+ // Neighbor Cache doesn't have access to view the NIC's assigned addresses.
+ HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver)
+
+ // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP
+ // reply or Neighbor Advertisement for ARP or NDP, respectively).
+ HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags)
+
+ // HandleUpperLevelConfirmation processes an incoming upper-level protocol
+ // (e.g. TCP acknowledgements) reachability confirmation.
+ HandleUpperLevelConfirmation(addr tcpip.Address)
+}
+
+// NUDConfigurations is the NUD configurations for the netstack. This is used
+// by the neighbor cache to operate the NUD state machine on each device in the
+// local network.
+type NUDConfigurations struct {
+ // BaseReachableTime is the base duration for computing the random reachable
+ // time.
+ //
+ // Reachable time is the duration for which a neighbor is considered
+ // reachable after a positive reachability confirmation is received. It is a
+ // function of uniformly distributed random value between minRandomFactor and
+ // maxRandomFactor multiplied by baseReachableTime. Using a random component
+ // eliminates the possibility that Neighbor Unreachability Detection messages
+ // will synchronize with each other.
+ //
+ // After this time, a neighbor entry will transition from REACHABLE to STALE
+ // state.
+ //
+ // Must be greater than 0.
+ BaseReachableTime time.Duration
+
+ // LearnBaseReachableTime enables learning BaseReachableTime during runtime
+ // from the neighbor discovery protocol, if supported.
+ //
+ // TODO(gvisor.dev/issue/2240): Implement this NUD configuration option.
+ LearnBaseReachableTime bool
+
+ // MinRandomFactor is the minimum value of the random factor used for
+ // computing reachable time.
+ //
+ // See BaseReachbleTime for more information on computing the reachable time.
+ //
+ // Must be greater than 0.
+ MinRandomFactor float32
+
+ // MaxRandomFactor is the maximum value of the random factor used for
+ // computing reachabile time.
+ //
+ // See BaseReachbleTime for more information on computing the reachable time.
+ //
+ // Must be great than or equal to MinRandomFactor.
+ MaxRandomFactor float32
+
+ // RetransmitTimer is the duration between retransmission of reachability
+ // probes in the PROBE state.
+ RetransmitTimer time.Duration
+
+ // LearnRetransmitTimer enables learning RetransmitTimer during runtime from
+ // the neighbor discovery protocol, if supported.
+ //
+ // TODO(gvisor.dev/issue/2241): Implement this NUD configuration option.
+ LearnRetransmitTimer bool
+
+ // DelayFirstProbeTime is the duration to wait for a non-Neighbor-Discovery
+ // related protocol to reconfirm reachability after entering the DELAY state.
+ // After this time, a reachability probe will be sent and the entry will
+ // transition to the PROBE state.
+ //
+ // Must be greater than 0.
+ DelayFirstProbeTime time.Duration
+
+ // MaxMulticastProbes is the number of reachability probes to send before
+ // concluding negative reachability and deleting the neighbor entry from the
+ // INCOMPLETE state.
+ //
+ // Must be greater than 0.
+ MaxMulticastProbes uint32
+
+ // MaxUnicastProbes is the number of reachability probes to send before
+ // concluding retransmission from within the PROBE state should cease and
+ // entry SHOULD be deleted.
+ //
+ // Must be greater than 0.
+ MaxUnicastProbes uint32
+
+ // MaxAnycastDelayTime is the time in which the stack SHOULD delay sending a
+ // response for a random time between 0 and this time, if the target address
+ // is an anycast address.
+ //
+ // TODO(gvisor.dev/issue/2242): Use this option when sending solicited
+ // neighbor confirmations to anycast addresses and proxying neighbor
+ // confirmations.
+ MaxAnycastDelayTime time.Duration
+
+ // MaxReachabilityConfirmations is the number of unsolicited reachability
+ // confirmation messages a node MAY send to all-node multicast address when
+ // it determines its link-layer address has changed.
+ //
+ // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD
+ // configuration option is necessary.
+ MaxReachabilityConfirmations uint32
+
+ // UnreachableTime describes how long an entry will remain in the FAILED
+ // state before being removed from the neighbor cache.
+ UnreachableTime time.Duration
+}
+
+// DefaultNUDConfigurations returns a NUDConfigurations populated with default
+// values defined by RFC 4861 section 10.
+func DefaultNUDConfigurations() NUDConfigurations {
+ return NUDConfigurations{
+ BaseReachableTime: defaultBaseReachableTime,
+ LearnBaseReachableTime: true,
+ MinRandomFactor: defaultMinRandomFactor,
+ MaxRandomFactor: defaultMaxRandomFactor,
+ RetransmitTimer: defaultRetransmitTimer,
+ LearnRetransmitTimer: true,
+ DelayFirstProbeTime: defaultDelayFirstProbeTime,
+ MaxMulticastProbes: defaultMaxMulticastProbes,
+ MaxUnicastProbes: defaultMaxUnicastProbes,
+ MaxAnycastDelayTime: defaultMaxAnycastDelayTime,
+ MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations,
+ UnreachableTime: defaultUnreachableTime,
+ }
+}
+
+// resetInvalidFields modifies an invalid NDPConfigurations with valid values.
+// If invalid values are present in c, the corresponding default values will be
+// used instead. This is needed to check, and conditionally fix, user-specified
+// NUDConfigurations.
+func (c *NUDConfigurations) resetInvalidFields() {
+ if c.BaseReachableTime < minimumBaseReachableTime {
+ c.BaseReachableTime = defaultBaseReachableTime
+ }
+ if c.MinRandomFactor <= 0 {
+ c.MinRandomFactor = defaultMinRandomFactor
+ }
+ if c.MaxRandomFactor < c.MinRandomFactor {
+ c.MaxRandomFactor = calcMaxRandomFactor(c.MinRandomFactor)
+ }
+ if c.RetransmitTimer < minimumRetransmitTimer {
+ c.RetransmitTimer = defaultRetransmitTimer
+ }
+ if c.DelayFirstProbeTime == 0 {
+ c.DelayFirstProbeTime = defaultDelayFirstProbeTime
+ }
+ if c.MaxMulticastProbes == 0 {
+ c.MaxMulticastProbes = defaultMaxMulticastProbes
+ }
+ if c.MaxUnicastProbes == 0 {
+ c.MaxUnicastProbes = defaultMaxUnicastProbes
+ }
+ if c.UnreachableTime == 0 {
+ c.UnreachableTime = defaultUnreachableTime
+ }
+}
+
+// calcMaxRandomFactor calculates the maximum value of the random factor used
+// for computing reachable time. This function is necessary for when the
+// default specified in RFC 4861 section 10 is less than the current
+// MinRandomFactor.
+//
+// Assumes minRandomFactor is positive since validation of the minimum value
+// should come before the validation of the maximum.
+func calcMaxRandomFactor(minRandomFactor float32) float32 {
+ if minRandomFactor > defaultMaxRandomFactor {
+ return minRandomFactor * 3
+ }
+ return defaultMaxRandomFactor
+}
+
+// A Rand is a source of random numbers.
+type Rand interface {
+ // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0).
+ Float32() float32
+}
+
+// NUDState stores states needed for calculating reachable time.
+type NUDState struct {
+ rng Rand
+
+ // mu protects the fields below.
+ //
+ // It is necessary for NUDState to handle its own locking since neighbor
+ // entries may access the NUD state from within the goroutine spawned by
+ // time.AfterFunc(). This goroutine may run concurrently with the main
+ // process for controlling the neighbor cache and would otherwise introduce
+ // race conditions if NUDState was not locked properly.
+ mu sync.RWMutex
+
+ config NUDConfigurations
+
+ // reachableTime is the duration to wait for a REACHABLE entry to
+ // transition into STALE after inactivity. This value is calculated with
+ // the algorithm defined in RFC 4861 section 6.3.2.
+ reachableTime time.Duration
+
+ expiration time.Time
+ prevBaseReachableTime time.Duration
+ prevMinRandomFactor float32
+ prevMaxRandomFactor float32
+}
+
+// NewNUDState returns new NUDState using c as configuration and the specified
+// random number generator for use in recomputing ReachableTime.
+func NewNUDState(c NUDConfigurations, rng Rand) *NUDState {
+ s := &NUDState{
+ rng: rng,
+ }
+ s.config = c
+ return s
+}
+
+// Config returns the NUD configuration.
+func (s *NUDState) Config() NUDConfigurations {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.config
+}
+
+// SetConfig replaces the existing NUD configurations with c.
+func (s *NUDState) SetConfig(c NUDConfigurations) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.config = c
+}
+
+// ReachableTime returns the duration to wait for a REACHABLE entry to
+// transition into STALE after inactivity. This value is recalculated for new
+// values of BaseReachableTime, MinRandomFactor, and MaxRandomFactor using the
+// algorithm defined in RFC 4861 section 6.3.2.
+func (s *NUDState) ReachableTime() time.Duration {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if time.Now().After(s.expiration) ||
+ s.config.BaseReachableTime != s.prevBaseReachableTime ||
+ s.config.MinRandomFactor != s.prevMinRandomFactor ||
+ s.config.MaxRandomFactor != s.prevMaxRandomFactor {
+ return s.recomputeReachableTimeLocked()
+ }
+ return s.reachableTime
+}
+
+// recomputeReachableTimeLocked forces a recalculation of ReachableTime using
+// the algorithm defined in RFC 4861 section 6.3.2.
+//
+// This SHOULD automatically be invoked during certain situations, as per
+// RFC 4861 section 6.3.4:
+//
+// If the received Reachable Time value is non-zero, the host SHOULD set its
+// BaseReachableTime variable to the received value. If the new value
+// differs from the previous value, the host SHOULD re-compute a new random
+// ReachableTime value. ReachableTime is computed as a uniformly
+// distributed random value between MIN_RANDOM_FACTOR and MAX_RANDOM_FACTOR
+// times the BaseReachableTime. Using a random component eliminates the
+// possibility that Neighbor Unreachability Detection messages will
+// synchronize with each other.
+//
+// In most cases, the advertised Reachable Time value will be the same in
+// consecutive Router Advertisements, and a host's BaseReachableTime rarely
+// changes. In such cases, an implementation SHOULD ensure that a new
+// random value gets re-computed at least once every few hours.
+//
+// s.mu MUST be locked for writing.
+func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
+ s.prevBaseReachableTime = s.config.BaseReachableTime
+ s.prevMinRandomFactor = s.config.MinRandomFactor
+ s.prevMaxRandomFactor = s.config.MaxRandomFactor
+
+ randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor)
+
+ // Check for overflow, given that minRandomFactor and maxRandomFactor are
+ // guaranteed to be positive numbers.
+ if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) {
+ s.reachableTime = time.Duration(math.MaxInt64)
+ } else if randomFactor == 1 {
+ // Avoid loss of precision when a large base reachable time is used.
+ s.reachableTime = s.config.BaseReachableTime
+ } else {
+ reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor)
+ s.reachableTime = time.Duration(reachableTime)
+ }
+
+ s.expiration = time.Now().Add(2 * time.Hour)
+ return s.reachableTime
+}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
new file mode 100644
index 000000000..8cffb9fc6
--- /dev/null
+++ b/pkg/tcpip/stack/nud_test.go
@@ -0,0 +1,807 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ defaultBaseReachableTime = 30 * time.Second
+ minimumBaseReachableTime = time.Millisecond
+ defaultMinRandomFactor = 0.5
+ defaultMaxRandomFactor = 1.5
+ defaultRetransmitTimer = time.Second
+ minimumRetransmitTimer = time.Millisecond
+ defaultDelayFirstProbeTime = 5 * time.Second
+ defaultMaxMulticastProbes = 3
+ defaultMaxUnicastProbes = 3
+ defaultMaxAnycastDelayTime = time.Second
+ defaultMaxReachbilityConfirmations = 3
+ defaultUnreachableTime = 5 * time.Second
+
+ defaultFakeRandomNum = 0.5
+)
+
+// fakeRand is a deterministic random number generator.
+type fakeRand struct {
+ num float32
+}
+
+var _ stack.Rand = (*fakeRand)(nil)
+
+func (f *fakeRand) Float32() float32 {
+ return f.num
+}
+
+// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if
+// we attempt to update NUD configurations using an invalid NICID.
+func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) {
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The networking
+ // stack will only allocate neighbor caches if a protocol providing link
+ // address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ UseNeighborCache: true,
+ })
+
+ // No NIC with ID 1 yet.
+ config := stack.NUDConfigurations{}
+ if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID)
+ }
+}
+
+// TestNUDConfigurationFailsForNotSupported tests to make sure we get a
+// NotSupported error if we attempt to retrieve NUD configurations when the
+// stack doesn't support NUD.
+//
+// The stack will report to not support NUD if a neighbor cache for a given NIC
+// is not allocated. The networking stack will only allocate neighbor caches if
+// a protocol providing link address resolution is specified (e.g. ARP, IPv6).
+func TestNUDConfigurationFailsForNotSupported(t *testing.T) {
+ const nicID = 1
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ UseNeighborCache: true,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported {
+ t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported)
+ }
+}
+
+// TestNUDConfigurationFailsForNotSupported tests to make sure we get a
+// NotSupported error if we attempt to set NUD configurations when the stack
+// doesn't support NUD.
+//
+// The stack will report to not support NUD if a neighbor cache for a given NIC
+// is not allocated. The networking stack will only allocate neighbor caches if
+// a protocol providing link address resolution is specified (e.g. ARP, IPv6).
+func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) {
+ const nicID = 1
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ UseNeighborCache: true,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ config := stack.NUDConfigurations{}
+ if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported {
+ t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported)
+ }
+}
+
+// TestDefaultNUDConfigurationIsValid verifies that calling
+// resetInvalidFields() on the result of DefaultNUDConfigurations() does not
+// change anything. DefaultNUDConfigurations() should return a valid
+// NUDConfigurations.
+func TestDefaultNUDConfigurations(t *testing.T) {
+ const nicID = 1
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The networking
+ // stack will only allocate neighbor caches if a protocol providing link
+ // address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ UseNeighborCache: true,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ c, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got, want := c, stack.DefaultNUDConfigurations(); got != want {
+ t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want)
+ }
+}
+
+func TestNUDConfigurationsBaseReachableTime(t *testing.T) {
+ tests := []struct {
+ name string
+ baseReachableTime time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ baseReachableTime: 0,
+ want: defaultBaseReachableTime,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ baseReachableTime: time.Millisecond,
+ want: time.Millisecond,
+ },
+ {
+ name: "MoreThanDefaultBaseReachableTime",
+ baseReachableTime: 2 * defaultBaseReachableTime,
+ want: 2 * defaultBaseReachableTime,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.BaseReachableTime = test.baseReachableTime
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ NUDConfigs: c,
+ UseNeighborCache: true,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.BaseReachableTime; got != test.want {
+ t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMinRandomFactor(t *testing.T) {
+ tests := []struct {
+ name string
+ minRandomFactor float32
+ want float32
+ }{
+ // Invalid cases
+ {
+ name: "LessThanZero",
+ minRandomFactor: -1,
+ want: defaultMinRandomFactor,
+ },
+ {
+ name: "EqualToZero",
+ minRandomFactor: 0,
+ want: defaultMinRandomFactor,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ minRandomFactor: 1,
+ want: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MinRandomFactor = test.minRandomFactor
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ NUDConfigs: c,
+ UseNeighborCache: true,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MinRandomFactor; got != test.want {
+ t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMaxRandomFactor(t *testing.T) {
+ tests := []struct {
+ name string
+ minRandomFactor float32
+ maxRandomFactor float32
+ want float32
+ }{
+ // Invalid cases
+ {
+ name: "LessThanZero",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: -1,
+ want: defaultMaxRandomFactor,
+ },
+ {
+ name: "EqualToZero",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: 0,
+ want: defaultMaxRandomFactor,
+ },
+ {
+ name: "LessThanMinRandomFactor",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: defaultMinRandomFactor * 0.99,
+ want: defaultMaxRandomFactor,
+ },
+ {
+ name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault",
+ minRandomFactor: defaultMaxRandomFactor * 2,
+ maxRandomFactor: defaultMaxRandomFactor,
+ want: defaultMaxRandomFactor * 6,
+ },
+ // Valid cases
+ {
+ name: "EqualToMinRandomFactor",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: defaultMinRandomFactor,
+ want: defaultMinRandomFactor,
+ },
+ {
+ name: "MoreThanMinRandomFactor",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: defaultMinRandomFactor * 1.1,
+ want: defaultMinRandomFactor * 1.1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MinRandomFactor = test.minRandomFactor
+ c.MaxRandomFactor = test.maxRandomFactor
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ NUDConfigs: c,
+ UseNeighborCache: true,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MaxRandomFactor; got != test.want {
+ t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsRetransmitTimer(t *testing.T) {
+ tests := []struct {
+ name string
+ retransmitTimer time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ retransmitTimer: 0,
+ want: defaultRetransmitTimer,
+ },
+ {
+ name: "LessThanMinimumRetransmitTimer",
+ retransmitTimer: minimumRetransmitTimer - time.Nanosecond,
+ want: defaultRetransmitTimer,
+ },
+ // Valid cases
+ {
+ name: "EqualToMinimumRetransmitTimer",
+ retransmitTimer: minimumRetransmitTimer,
+ want: minimumBaseReachableTime,
+ },
+ {
+ name: "LargetThanMinimumRetransmitTimer",
+ retransmitTimer: 2 * minimumBaseReachableTime,
+ want: 2 * minimumBaseReachableTime,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.RetransmitTimer = test.retransmitTimer
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ NUDConfigs: c,
+ UseNeighborCache: true,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.RetransmitTimer; got != test.want {
+ t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) {
+ tests := []struct {
+ name string
+ delayFirstProbeTime time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ delayFirstProbeTime: 0,
+ want: defaultDelayFirstProbeTime,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ delayFirstProbeTime: time.Millisecond,
+ want: time.Millisecond,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.DelayFirstProbeTime = test.delayFirstProbeTime
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ NUDConfigs: c,
+ UseNeighborCache: true,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.DelayFirstProbeTime; got != test.want {
+ t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) {
+ tests := []struct {
+ name string
+ maxMulticastProbes uint32
+ want uint32
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ maxMulticastProbes: 0,
+ want: defaultMaxMulticastProbes,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ maxMulticastProbes: 1,
+ want: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MaxMulticastProbes = test.maxMulticastProbes
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ NUDConfigs: c,
+ UseNeighborCache: true,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MaxMulticastProbes; got != test.want {
+ t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
+ tests := []struct {
+ name string
+ maxUnicastProbes uint32
+ want uint32
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ maxUnicastProbes: 0,
+ want: defaultMaxUnicastProbes,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ maxUnicastProbes: 1,
+ want: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MaxUnicastProbes = test.maxUnicastProbes
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ NUDConfigs: c,
+ UseNeighborCache: true,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MaxUnicastProbes; got != test.want {
+ t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsUnreachableTime(t *testing.T) {
+ tests := []struct {
+ name string
+ unreachableTime time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ unreachableTime: 0,
+ want: defaultUnreachableTime,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ unreachableTime: time.Millisecond,
+ want: time.Millisecond,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.UnreachableTime = test.unreachableTime
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ NUDConfigs: c,
+ UseNeighborCache: true,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.UnreachableTime; got != test.want {
+ t.Errorf("got UnreachableTime = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+// TestNUDStateReachableTime verifies the correctness of the ReachableTime
+// computation.
+func TestNUDStateReachableTime(t *testing.T) {
+ tests := []struct {
+ name string
+ baseReachableTime time.Duration
+ minRandomFactor float32
+ maxRandomFactor float32
+ want time.Duration
+ }{
+ {
+ name: "AllZeros",
+ baseReachableTime: 0,
+ minRandomFactor: 0,
+ maxRandomFactor: 0,
+ want: 0,
+ },
+ {
+ name: "ZeroMaxRandomFactor",
+ baseReachableTime: time.Second,
+ minRandomFactor: 0,
+ maxRandomFactor: 0,
+ want: 0,
+ },
+ {
+ name: "ZeroMinRandomFactor",
+ baseReachableTime: time.Second,
+ minRandomFactor: 0,
+ maxRandomFactor: 1,
+ want: time.Duration(defaultFakeRandomNum * float32(time.Second)),
+ },
+ {
+ name: "FractionalRandomFactor",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 0.001,
+ maxRandomFactor: 0.002,
+ want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)),
+ },
+ {
+ name: "MinAndMaxRandomFactorsEqual",
+ baseReachableTime: time.Second,
+ minRandomFactor: 1,
+ maxRandomFactor: 1,
+ want: time.Second,
+ },
+ {
+ name: "MinAndMaxRandomFactorsDifferent",
+ baseReachableTime: time.Second,
+ minRandomFactor: 1,
+ maxRandomFactor: 2,
+ want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)),
+ },
+ {
+ name: "MaxInt64",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 1,
+ maxRandomFactor: 1,
+ want: time.Duration(math.MaxInt64),
+ },
+ {
+ name: "Overflow",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 1.5,
+ maxRandomFactor: 1.5,
+ want: time.Duration(math.MaxInt64),
+ },
+ {
+ name: "DoubleOverflow",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 2.5,
+ maxRandomFactor: 2.5,
+ want: time.Duration(math.MaxInt64),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := stack.NUDConfigurations{
+ BaseReachableTime: test.baseReachableTime,
+ MinRandomFactor: test.minRandomFactor,
+ MaxRandomFactor: test.maxRandomFactor,
+ }
+ // A fake random number generator is used to ensure deterministic
+ // results.
+ rng := fakeRand{
+ num: defaultFakeRandomNum,
+ }
+ s := stack.NewNUDState(c, &rng)
+ if got, want := s.ReachableTime(), test.want; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+ })
+ }
+}
+
+// TestNUDStateRecomputeReachableTime exercises the ReachableTime function
+// twice to verify recomputation of reachable time when the min random factor,
+// max random factor, or base reachable time changes.
+func TestNUDStateRecomputeReachableTime(t *testing.T) {
+ const defaultBase = time.Second
+ const defaultMin = 2.0 * defaultMaxRandomFactor
+ const defaultMax = 3.0 * defaultMaxRandomFactor
+
+ tests := []struct {
+ name string
+ baseReachableTime time.Duration
+ minRandomFactor float32
+ maxRandomFactor float32
+ want time.Duration
+ }{
+ {
+ name: "BaseReachableTime",
+ baseReachableTime: 2 * defaultBase,
+ minRandomFactor: defaultMin,
+ maxRandomFactor: defaultMax,
+ want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)),
+ },
+ {
+ name: "MinRandomFactor",
+ baseReachableTime: defaultBase,
+ minRandomFactor: defaultMax,
+ maxRandomFactor: defaultMax,
+ want: time.Duration(defaultMax * float32(defaultBase)),
+ },
+ {
+ name: "MaxRandomFactor",
+ baseReachableTime: defaultBase,
+ minRandomFactor: defaultMin,
+ maxRandomFactor: defaultMin,
+ want: time.Duration(defaultMin * float32(defaultBase)),
+ },
+ {
+ name: "BothRandomFactor",
+ baseReachableTime: defaultBase,
+ minRandomFactor: 2 * defaultMin,
+ maxRandomFactor: 2 * defaultMax,
+ want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)),
+ },
+ {
+ name: "BaseReachableTimeAndBothRandomFactors",
+ baseReachableTime: 2 * defaultBase,
+ minRandomFactor: 2 * defaultMin,
+ maxRandomFactor: 2 * defaultMax,
+ want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := stack.DefaultNUDConfigurations()
+ c.BaseReachableTime = defaultBase
+ c.MinRandomFactor = defaultMin
+ c.MaxRandomFactor = defaultMax
+
+ // A fake random number generator is used to ensure deterministic
+ // results.
+ rng := fakeRand{
+ num: defaultFakeRandomNum,
+ }
+ s := stack.NewNUDState(c, &rng)
+ old := s.ReachableTime()
+
+ if got, want := s.ReachableTime(), old; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+
+ // Check for recomputation when changing the min random factor, the max
+ // random factor, the base reachability time, or any permutation of those
+ // three options.
+ c.BaseReachableTime = test.baseReachableTime
+ c.MinRandomFactor = test.minRandomFactor
+ c.MaxRandomFactor = test.maxRandomFactor
+ s.SetConfig(c)
+
+ if got, want := s.ReachableTime(), test.want; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+
+ // Verify that ReachableTime isn't recomputed when none of the
+ // configuration options change. The random factor is changed so that if
+ // a recompution were to occur, ReachableTime would change.
+ rng.num = defaultFakeRandomNum / 2.0
+ if got, want := s.ReachableTime(), test.want; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 1b5da6017..105583c49 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -14,52 +14,83 @@
package stack
import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
+type headerType int
+
+const (
+ linkHeader headerType = iota
+ networkHeader
+ transportHeader
+ numHeaderType
+)
+
+// PacketBufferOptions specifies options for PacketBuffer creation.
+type PacketBufferOptions struct {
+ // ReserveHeaderBytes is the number of bytes to reserve for headers. Total
+ // number of bytes pushed onto the headers must not exceed this value.
+ ReserveHeaderBytes int
+
+ // Data is the initial unparsed data for the new packet. If set, it will be
+ // owned by the new packet.
+ Data buffer.VectorisedView
+}
+
// A PacketBuffer contains all the data of a network packet.
//
// As a PacketBuffer traverses up the stack, it may be necessary to pass it to
-// multiple endpoints. Clone() should be called in such cases so that
-// modifications to the Data field do not affect other copies.
+// multiple endpoints.
+//
+// The whole packet is expected to be a series of bytes in the following order:
+// LinkHeader, NetworkHeader, TransportHeader, and Data. Any of them can be
+// empty. Use of PacketBuffer in any other order is unsupported.
+//
+// PacketBuffer must be created with NewPacketBuffer.
type PacketBuffer struct {
- _ noCopy
+ _ sync.NoCopy
// PacketBufferEntry is used to build an intrusive list of
// PacketBuffers.
PacketBufferEntry
- // Data holds the payload of the packet. For inbound packets, it also
- // holds the headers, which are consumed as the packet moves up the
- // stack. Headers are guaranteed not to be split across views.
+ // Data holds the payload of the packet.
+ //
+ // For inbound packets, Data is initially the whole packet. Then gets moved to
+ // headers via PacketHeader.Consume, when the packet is being parsed.
+ //
+ // For outbound packets, Data is the innermost layer, defined by the protocol.
+ // Headers are pushed in front of it via PacketHeader.Push.
//
- // The bytes backing Data are immutable, but Data itself may be trimmed
- // or otherwise modified.
+ // The bytes backing Data are immutable, a.k.a. users shouldn't write to its
+ // backing storage.
Data buffer.VectorisedView
- // Header holds the headers of outbound packets. As a packet is passed
- // down the stack, each layer adds to Header. Note that forwarded
- // packets don't populate Headers on their way out -- their headers and
- // payload are never parsed out and remain in Data.
- //
- // TODO(gvisor.dev/issue/170): Forwarded packets don't currently
- // populate Header, but should. This will be doable once early parsing
- // (https://github.com/google/gvisor/pull/1995) is supported.
- Header buffer.Prependable
+ // headers stores metadata about each header.
+ headers [numHeaderType]headerInfo
- // These fields are used by both inbound and outbound packets. They
- // typically overlap with the Data and Header fields.
- //
- // The bytes backing these views are immutable. Each field may be nil
- // if either it has not been set yet or no such header exists (e.g.
- // packets sent via loopback may not have a link header).
+ // header is the internal storage for outbound packets. Headers will be pushed
+ // (prepended) on this storage as the packet is being constructed.
//
- // These fields may be Views into other slices (either Data or Header).
- // SR dosen't support this, so deep copies are necessary in some cases.
- LinkHeader buffer.View
- NetworkHeader buffer.View
- TransportHeader buffer.View
+ // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and
+ // data are held in the same underlying buffer storage.
+ header buffer.Prependable
+
+ // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty()
+ // returns false.
+ // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol
+ // numbers in registration APIs that take a PacketBuffer.
+ NetworkProtocolNumber tcpip.NetworkProtocolNumber
+
+ // TransportProtocol is only valid if it is non zero.
+ // TODO(gvisor.dev/issue/3810): This and the network protocol number should
+ // be moved into the headerinfo. This should resolve the validity issue.
+ TransportProtocolNumber tcpip.TransportProtocolNumber
// Hash is the transport layer hash of this packet. A value of zero
// indicates no valid hash has been set.
@@ -71,45 +102,220 @@ type PacketBuffer struct {
// The following fields are only set by the qdisc layer when the packet
// is added to a queue.
- EgressRoute *Route
- GSOOptions *GSO
- NetworkProtocolNumber tcpip.NetworkProtocolNumber
+ EgressRoute *Route
+ GSOOptions *GSO
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
NatDone bool
+
+ // PktType indicates the SockAddrLink.PacketType of the packet as defined in
+ // https://www.man7.org/linux/man-pages/man7/packet.7.html.
+ PktType tcpip.PacketType
}
-// Clone makes a copy of pk. It clones the Data field, which creates a new
-// VectorisedView but does not deep copy the underlying bytes.
-//
-// Clone also does not deep copy any of its other fields.
+// NewPacketBuffer creates a new PacketBuffer with opts.
+func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer {
+ pk := &PacketBuffer{
+ Data: opts.Data,
+ }
+ if opts.ReserveHeaderBytes != 0 {
+ pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes)
+ }
+ return pk
+}
+
+// ReservedHeaderBytes returns the number of bytes initially reserved for
+// headers.
+func (pk *PacketBuffer) ReservedHeaderBytes() int {
+ return pk.header.UsedLength() + pk.header.AvailableLength()
+}
+
+// AvailableHeaderBytes returns the number of bytes currently available for
+// headers. This is relevant to PacketHeader.Push method only.
+func (pk *PacketBuffer) AvailableHeaderBytes() int {
+ return pk.header.AvailableLength()
+}
+
+// LinkHeader returns the handle to link-layer header.
+func (pk *PacketBuffer) LinkHeader() PacketHeader {
+ return PacketHeader{
+ pk: pk,
+ typ: linkHeader,
+ }
+}
+
+// NetworkHeader returns the handle to network-layer header.
+func (pk *PacketBuffer) NetworkHeader() PacketHeader {
+ return PacketHeader{
+ pk: pk,
+ typ: networkHeader,
+ }
+}
+
+// TransportHeader returns the handle to transport-layer header.
+func (pk *PacketBuffer) TransportHeader() PacketHeader {
+ return PacketHeader{
+ pk: pk,
+ typ: transportHeader,
+ }
+}
+
+// HeaderSize returns the total size of all headers in bytes.
+func (pk *PacketBuffer) HeaderSize() int {
+ // Note for inbound packets (Consume called), headers are not stored in
+ // pk.header. Thus, calculation of size of each header is needed.
+ var size int
+ for i := range pk.headers {
+ size += len(pk.headers[i].buf)
+ }
+ return size
+}
+
+// Size returns the size of packet in bytes.
+func (pk *PacketBuffer) Size() int {
+ return pk.HeaderSize() + pk.Data.Size()
+}
+
+// Views returns the underlying storage of the whole packet.
+func (pk *PacketBuffer) Views() []buffer.View {
+ // Optimization for outbound packets that headers are in pk.header.
+ useHeader := true
+ for i := range pk.headers {
+ if !canUseHeader(&pk.headers[i]) {
+ useHeader = false
+ break
+ }
+ }
+
+ dataViews := pk.Data.Views()
+
+ var vs []buffer.View
+ if useHeader {
+ vs = make([]buffer.View, 0, 1+len(dataViews))
+ vs = append(vs, pk.header.View())
+ } else {
+ vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews))
+ for i := range pk.headers {
+ if v := pk.headers[i].buf; len(v) > 0 {
+ vs = append(vs, v)
+ }
+ }
+ }
+ return append(vs, dataViews...)
+}
+
+func canUseHeader(h *headerInfo) bool {
+ // h.offset will be negative if the header was pushed in to prependable
+ // portion, or doesn't matter when it's empty.
+ return len(h.buf) == 0 || h.offset < 0
+}
+
+func (pk *PacketBuffer) push(typ headerType, size int) buffer.View {
+ h := &pk.headers[typ]
+ if h.buf != nil {
+ panic(fmt.Sprintf("push must not be called twice: type %s", typ))
+ }
+ h.buf = buffer.View(pk.header.Prepend(size))
+ h.offset = -pk.header.UsedLength()
+ return h.buf
+}
+
+func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) {
+ h := &pk.headers[typ]
+ if h.buf != nil {
+ panic(fmt.Sprintf("consume must not be called twice: type %s", typ))
+ }
+ v, ok := pk.Data.PullUp(size)
+ if !ok {
+ return
+ }
+ pk.Data.TrimFront(size)
+ h.buf = v
+ return h.buf, true
+}
+
+// Clone makes a shallow copy of pk.
//
-// FIXME(b/153685824): Data gets copied but not other header references.
+// Clone should be called in such cases so that no modifications is done to
+// underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
- return &PacketBuffer{
- PacketBufferEntry: pk.PacketBufferEntry,
- Data: pk.Data.Clone(nil),
- Header: pk.Header,
- LinkHeader: pk.LinkHeader,
- NetworkHeader: pk.NetworkHeader,
- TransportHeader: pk.TransportHeader,
- Hash: pk.Hash,
- Owner: pk.Owner,
- EgressRoute: pk.EgressRoute,
- GSOOptions: pk.GSOOptions,
- NetworkProtocolNumber: pk.NetworkProtocolNumber,
- NatDone: pk.NatDone,
+ newPk := &PacketBuffer{
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ headers: pk.headers,
+ header: pk.header,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ EgressRoute: pk.EgressRoute,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ TransportProtocolNumber: pk.TransportProtocolNumber,
}
+ return newPk
}
-// noCopy may be embedded into structs which must not be copied
-// after the first use.
+// Network returns the network header as a header.Network.
//
-// See https://golang.org/issues/8005#issuecomment-190753527
-// for details.
-type noCopy struct{}
+// Network should only be called when NetworkHeader has been set.
+func (pk *PacketBuffer) Network() header.Network {
+ switch netProto := pk.NetworkProtocolNumber; netProto {
+ case header.IPv4ProtocolNumber:
+ return header.IPv4(pk.NetworkHeader().View())
+ case header.IPv6ProtocolNumber:
+ return header.IPv6(pk.NetworkHeader().View())
+ default:
+ panic(fmt.Sprintf("unknown network protocol number %d", netProto))
+ }
+}
+
+// headerInfo stores metadata about a header in a packet.
+type headerInfo struct {
+ // buf is the memorized slice for both prepended and consumed header.
+ // When header is prepended, buf serves as memorized value, which is a slice
+ // of pk.header. When header is consumed, buf is the slice pulled out from
+ // pk.Data, which is the only place to hold this header.
+ buf buffer.View
+
+ // offset will be a negative number denoting the offset where this header is
+ // from the end of pk.header, if it is prepended. Otherwise, zero.
+ offset int
+}
+
+// PacketHeader is a handle object to a header in the underlying packet.
+type PacketHeader struct {
+ pk *PacketBuffer
+ typ headerType
+}
+
+// View returns the underlying storage of h.
+func (h PacketHeader) View() buffer.View {
+ return h.pk.headers[h.typ].buf
+}
+
+// Push pushes size bytes in the front of its residing packet, and returns the
+// backing storage. Callers may only call one of Push or Consume once on each
+// header in the lifetime of the underlying packet.
+func (h PacketHeader) Push(size int) buffer.View {
+ return h.pk.push(h.typ, size)
+}
-// Lock is a no-op used by -copylocks checker from `go vet`.
-func (*noCopy) Lock() {}
-func (*noCopy) Unlock() {}
+// Consume moves the first size bytes of the unparsed data portion in the packet
+// to h, and returns the backing storage. In the case of data is shorter than
+// size, consumed will be false, and the state of h will not be affected.
+// Callers may only call one of Push or Consume once on each header in the
+// lifetime of the underlying packet.
+func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) {
+ return h.pk.consume(h.typ, size)
+}
+
+// PayloadSince returns packet payload starting from and including a particular
+// header. This method isn't optimized and should be used in test only.
+func PayloadSince(h PacketHeader) buffer.View {
+ var v buffer.View
+ for _, hinfo := range h.pk.headers[h.typ:] {
+ v = append(v, hinfo.buf...)
+ }
+ return append(v, h.pk.Data.ToView()...)
+}
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
new file mode 100644
index 000000000..c6fa8da5f
--- /dev/null
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -0,0 +1,397 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "bytes"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func TestPacketHeaderPush(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ reserved int
+ link []byte
+ network []byte
+ transport []byte
+ data []byte
+ }{
+ {
+ name: "construct empty packet",
+ },
+ {
+ name: "construct link header only packet",
+ reserved: 60,
+ link: makeView(10),
+ },
+ {
+ name: "construct link and network header only packet",
+ reserved: 60,
+ link: makeView(10),
+ network: makeView(20),
+ },
+ {
+ name: "construct header only packet",
+ reserved: 60,
+ link: makeView(10),
+ network: makeView(20),
+ transport: makeView(30),
+ },
+ {
+ name: "construct data only packet",
+ data: makeView(40),
+ },
+ {
+ name: "construct L3 packet",
+ reserved: 60,
+ network: makeView(20),
+ transport: makeView(30),
+ data: makeView(40),
+ },
+ {
+ name: "construct L2 packet",
+ reserved: 60,
+ link: makeView(10),
+ network: makeView(20),
+ transport: makeView(30),
+ data: makeView(40),
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: test.reserved,
+ // Make a copy of data to make sure our truth data won't be taint by
+ // PacketBuffer.
+ Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(),
+ })
+
+ allHdrSize := len(test.link) + len(test.network) + len(test.transport)
+
+ // Check the initial values for packet.
+ checkInitialPacketBuffer(t, pk, PacketBufferOptions{
+ ReserveHeaderBytes: test.reserved,
+ Data: buffer.View(test.data).ToVectorisedView(),
+ })
+
+ // Push headers.
+ if v := test.transport; len(v) > 0 {
+ copy(pk.TransportHeader().Push(len(v)), v)
+ }
+ if v := test.network; len(v) > 0 {
+ copy(pk.NetworkHeader().Push(len(v)), v)
+ }
+ if v := test.link; len(v) > 0 {
+ copy(pk.LinkHeader().Push(len(v)), v)
+ }
+
+ // Check the after values for packet.
+ if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want {
+ t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want {
+ t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.HeaderSize(), allHdrSize; got != want {
+ t.Errorf("After pk.HeaderSize() = %d, want %d", got, want)
+ }
+ if got, want := pk.Size(), allHdrSize+len(test.data); got != want {
+ t.Errorf("After pk.Size() = %d, want %d", got, want)
+ }
+ checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data)
+ checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...),
+ concatViews(test.link, test.network, test.transport, test.data))
+ // Check the after values for each header.
+ checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link)
+ checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network)
+ checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport)
+ // Check the after values for PayloadSince.
+ checkViewEqual(t, "After PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()),
+ concatViews(test.link, test.network, test.transport, test.data))
+ checkViewEqual(t, "After PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()),
+ concatViews(test.network, test.transport, test.data))
+ checkViewEqual(t, "After PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()),
+ concatViews(test.transport, test.data))
+ })
+ }
+}
+
+func TestPacketHeaderConsume(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ data []byte
+ link int
+ network int
+ transport int
+ }{
+ {
+ name: "parse L2 packet",
+ data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)),
+ link: 10,
+ network: 20,
+ transport: 30,
+ },
+ {
+ name: "parse L3 packet",
+ data: concatViews(makeView(20), makeView(30), makeView(40)),
+ network: 20,
+ transport: 30,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pk := NewPacketBuffer(PacketBufferOptions{
+ // Make a copy of data to make sure our truth data won't be taint by
+ // PacketBuffer.
+ Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(),
+ })
+
+ // Check the initial values for packet.
+ checkInitialPacketBuffer(t, pk, PacketBufferOptions{
+ Data: buffer.View(test.data).ToVectorisedView(),
+ })
+
+ // Consume headers.
+ if size := test.link; size > 0 {
+ if _, ok := pk.LinkHeader().Consume(size); !ok {
+ t.Fatalf("pk.LinkHeader().Consume() = false, want true")
+ }
+ }
+ if size := test.network; size > 0 {
+ if _, ok := pk.NetworkHeader().Consume(size); !ok {
+ t.Fatalf("pk.NetworkHeader().Consume() = false, want true")
+ }
+ }
+ if size := test.transport; size > 0 {
+ if _, ok := pk.TransportHeader().Consume(size); !ok {
+ t.Fatalf("pk.TransportHeader().Consume() = false, want true")
+ }
+ }
+
+ allHdrSize := test.link + test.network + test.transport
+
+ // Check the after values for packet.
+ if got, want := pk.ReservedHeaderBytes(), 0; got != want {
+ t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.AvailableHeaderBytes(), 0; got != want {
+ t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.HeaderSize(), allHdrSize; got != want {
+ t.Errorf("After pk.HeaderSize() = %d, want %d", got, want)
+ }
+ if got, want := pk.Size(), len(test.data); got != want {
+ t.Errorf("After pk.Size() = %d, want %d", got, want)
+ }
+ // After state of pk.
+ var (
+ link = test.data[:test.link]
+ network = test.data[test.link:][:test.network]
+ transport = test.data[test.link+test.network:][:test.transport]
+ payload = test.data[allHdrSize:]
+ )
+ checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload)
+ checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data)
+ // Check the after values for each header.
+ checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link)
+ checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network)
+ checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport)
+ // Check the after values for PayloadSince.
+ checkViewEqual(t, "After PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()),
+ concatViews(link, network, transport, payload))
+ checkViewEqual(t, "After PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()),
+ concatViews(network, transport, payload))
+ checkViewEqual(t, "After PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()),
+ concatViews(transport, payload))
+ })
+ }
+}
+
+func TestPacketHeaderConsumeDataTooShort(t *testing.T) {
+ data := makeView(10)
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ // Make a copy of data to make sure our truth data won't be taint by
+ // PacketBuffer.
+ Data: buffer.NewViewFromBytes(data).ToVectorisedView(),
+ })
+
+ // Consume should fail if pkt.Data is too short.
+ if _, ok := pk.LinkHeader().Consume(11); ok {
+ t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false")
+ }
+ if _, ok := pk.NetworkHeader().Consume(11); ok {
+ t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false")
+ }
+ if _, ok := pk.TransportHeader().Consume(11); ok {
+ t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false")
+ }
+
+ // Check packet should look the same as initial packet.
+ checkInitialPacketBuffer(t, pk, PacketBufferOptions{
+ Data: buffer.View(data).ToVectorisedView(),
+ })
+}
+
+func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) {
+ const headerSize = 10
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: headerSize * int(numHeaderType),
+ })
+
+ for _, h := range []PacketHeader{
+ pk.TransportHeader(),
+ pk.NetworkHeader(),
+ pk.LinkHeader(),
+ } {
+ t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) {
+ h.Push(headerSize)
+
+ defer func() { recover() }()
+ h.Push(headerSize)
+ t.Fatal("Second push should have panicked")
+ })
+ }
+}
+
+func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) {
+ const headerSize = 10
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(),
+ })
+
+ for _, h := range []PacketHeader{
+ pk.LinkHeader(),
+ pk.NetworkHeader(),
+ pk.TransportHeader(),
+ } {
+ t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) {
+ if _, ok := h.Consume(headerSize); !ok {
+ t.Fatal("First consume should succeed")
+ }
+
+ defer func() { recover() }()
+ h.Consume(headerSize)
+ t.Fatal("Second consume should have panicked")
+ })
+ }
+}
+
+func TestPacketHeaderPushThenConsumePanics(t *testing.T) {
+ const headerSize = 10
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: headerSize * int(numHeaderType),
+ })
+
+ for _, h := range []PacketHeader{
+ pk.TransportHeader(),
+ pk.NetworkHeader(),
+ pk.LinkHeader(),
+ } {
+ t.Run(h.typ.String(), func(t *testing.T) {
+ h.Push(headerSize)
+
+ defer func() { recover() }()
+ h.Consume(headerSize)
+ t.Fatal("Consume should have panicked")
+ })
+ }
+}
+
+func TestPacketHeaderConsumeThenPushPanics(t *testing.T) {
+ const headerSize = 10
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(),
+ })
+
+ for _, h := range []PacketHeader{
+ pk.LinkHeader(),
+ pk.NetworkHeader(),
+ pk.TransportHeader(),
+ } {
+ t.Run(h.typ.String(), func(t *testing.T) {
+ h.Consume(headerSize)
+
+ defer func() { recover() }()
+ h.Push(headerSize)
+ t.Fatal("Push should have panicked")
+ })
+ }
+}
+
+func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) {
+ t.Helper()
+ reserved := opts.ReserveHeaderBytes
+ if got, want := pk.ReservedHeaderBytes(), reserved; got != want {
+ t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.AvailableHeaderBytes(), reserved; got != want {
+ t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.HeaderSize(), 0; got != want {
+ t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want)
+ }
+ data := opts.Data.ToView()
+ if got, want := pk.Size(), len(data); got != want {
+ t.Errorf("Initial pk.Size() = %d, want %d", got, want)
+ }
+ checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data)
+ checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data)
+ // Check the initial values for each header.
+ checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil)
+ checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil)
+ checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil)
+ // Check the initial valies for PayloadSince.
+ checkViewEqual(t, "Initial PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()), data)
+ checkViewEqual(t, "Initial PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()), data)
+ checkViewEqual(t, "Initial PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()), data)
+}
+
+func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) {
+ t.Helper()
+ checkViewEqual(t, name+".View()", h.View(), want)
+}
+
+func checkViewEqual(t *testing.T, what string, got, want buffer.View) {
+ t.Helper()
+ if !bytes.Equal(got, want) {
+ t.Errorf("%s = %x, want %x", what, got, want)
+ }
+}
+
+func makeView(size int) buffer.View {
+ b := byte(size)
+ return bytes.Repeat([]byte{b}, size)
+}
+
+func concatViews(views ...buffer.View) buffer.View {
+ var all buffer.View
+ for _, v := range views {
+ all = append(all, v...)
+ }
+ return all
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 5cbc946b6..be9bd8042 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -15,9 +15,12 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -51,8 +54,11 @@ type TransportEndpointID struct {
type ControlType int
// The following are the allowed values for ControlType values.
+// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages.
const (
- ControlPacketTooBig ControlType = iota
+ ControlNetworkUnreachable ControlType = iota
+ ControlNoRoute
+ ControlPacketTooBig
ControlPortUnreachable
ControlUnknown
)
@@ -121,6 +127,26 @@ type PacketEndpoint interface {
HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
+// UnknownDestinationPacketDisposition enumerates the possible return vaues from
+// HandleUnknownDestinationPacket().
+type UnknownDestinationPacketDisposition int
+
+const (
+ // UnknownDestinationPacketMalformed denotes that the packet was malformed
+ // and no further processing should be attempted other than updating
+ // statistics.
+ UnknownDestinationPacketMalformed UnknownDestinationPacketDisposition = iota
+
+ // UnknownDestinationPacketUnhandled tells the caller that the packet was
+ // well formed but that the issue was not handled and the stack should take
+ // the default action.
+ UnknownDestinationPacketUnhandled
+
+ // UnknownDestinationPacketHandled tells the caller that it should do
+ // no further processing.
+ UnknownDestinationPacketHandled
+)
+
// TransportProtocol is the interface that needs to be implemented by transport
// protocols (e.g., tcp, udp) that want to be part of the networking stack.
type TransportProtocol interface {
@@ -128,10 +154,10 @@ type TransportProtocol interface {
Number() tcpip.TransportProtocolNumber
// NewEndpoint creates a new endpoint of the transport protocol.
- NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
// NewRawEndpoint creates a new raw endpoint of the transport protocol.
- NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
// MinimumPacketSize returns the minimum valid packet size of this
// transport protocol. The stack automatically drops any packets smaller
@@ -143,24 +169,22 @@ type TransportProtocol interface {
ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this
- // protocol but that don't match any existing endpoint. For example,
- // it is targeted at a port that have no listeners.
- //
- // The return value indicates whether the packet was well-formed (for
- // stats purposes only).
+ // protocol that don't match any existing endpoint. For example,
+ // it is targeted at a port that has no listeners.
//
- // HandleUnknownDestinationPacket takes ownership of pkt.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
+ // HandleUnknownDestinationPacket takes ownership of pkt if it handles
+ // the issue.
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
- SetOption(option interface{}) *tcpip.Error
+ SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
- Option(option interface{}) *tcpip.Error
+ Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error
// Close requests that any worker goroutines owned by the protocol
// stop.
@@ -175,6 +199,25 @@ type TransportProtocol interface {
Parse(pkt *PacketBuffer) (ok bool)
}
+// TransportPacketDisposition is the result from attempting to deliver a packet
+// to the transport layer.
+type TransportPacketDisposition int
+
+const (
+ // TransportPacketHandled indicates that a transport packet was handled by the
+ // transport layer and callers need not take any further action.
+ TransportPacketHandled TransportPacketDisposition = iota
+
+ // TransportPacketProtocolUnreachable indicates that the transport
+ // protocol requested in the packet is not supported.
+ TransportPacketProtocolUnreachable
+
+ // TransportPacketDestinationPortUnreachable indicates that there weren't any
+ // listeners interested in the packet and the transport protocol has no means
+ // to notify the sender.
+ TransportPacketDestinationPortUnreachable
+)
+
// TransportDispatcher contains the methods used by the network stack to deliver
// packets to the appropriate transport endpoint after it has been handled by
// the network layer.
@@ -185,7 +228,7 @@ type TransportDispatcher interface {
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
// DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer)
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -222,9 +265,253 @@ type NetworkHeaderParams struct {
TOS uint8
}
+// GroupAddressableEndpoint is an endpoint that supports group addressing.
+//
+// An endpoint is considered to support group addressing when one or more
+// endpoints may associate themselves with the same identifier (group address).
+type GroupAddressableEndpoint interface {
+ // JoinGroup joins the spcified group.
+ //
+ // Returns true if the group was newly joined.
+ JoinGroup(group tcpip.Address) (bool, *tcpip.Error)
+
+ // LeaveGroup attempts to leave the specified group.
+ //
+ // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group.
+ LeaveGroup(group tcpip.Address) (bool, *tcpip.Error)
+
+ // IsInGroup returns true if the endpoint is a member of the specified group.
+ IsInGroup(group tcpip.Address) bool
+}
+
+// PrimaryEndpointBehavior is an enumeration of an AddressEndpoint's primary
+// behavior.
+type PrimaryEndpointBehavior int
+
+const (
+ // CanBePrimaryEndpoint indicates the endpoint can be used as a primary
+ // endpoint for new connections with no local address. This is the
+ // default when calling NIC.AddAddress.
+ CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
+
+ // FirstPrimaryEndpoint indicates the endpoint should be the first
+ // primary endpoint considered. If there are multiple endpoints with
+ // this behavior, they are ordered by recency.
+ FirstPrimaryEndpoint
+
+ // NeverPrimaryEndpoint indicates the endpoint should never be a
+ // primary endpoint.
+ NeverPrimaryEndpoint
+)
+
+// AddressConfigType is the method used to add an address.
+type AddressConfigType int
+
+const (
+ // AddressConfigStatic is a statically configured address endpoint that was
+ // added by some user-specified action (adding an explicit address, joining a
+ // multicast group).
+ AddressConfigStatic AddressConfigType = iota
+
+ // AddressConfigSlaac is an address endpoint added by SLAAC, as per RFC 4862
+ // section 5.5.3.
+ AddressConfigSlaac
+
+ // AddressConfigSlaacTemp is a temporary address endpoint added by SLAAC as
+ // per RFC 4941. Temporary SLAAC addresses are short-lived and are not
+ // to be valid (or preferred) forever; hence the term temporary.
+ AddressConfigSlaacTemp
+)
+
+// AssignableAddressEndpoint is a reference counted address endpoint that may be
+// assigned to a NetworkEndpoint.
+type AssignableAddressEndpoint interface {
+ // AddressWithPrefix returns the endpoint's address.
+ AddressWithPrefix() tcpip.AddressWithPrefix
+
+ // IsAssigned returns whether or not the endpoint is considered bound
+ // to its NetworkEndpoint.
+ IsAssigned(allowExpired bool) bool
+
+ // IncRef increments this endpoint's reference count.
+ //
+ // Returns true if it was successfully incremented. If it returns false, then
+ // the endpoint is considered expired and should no longer be used.
+ IncRef() bool
+
+ // DecRef decrements this endpoint's reference count.
+ DecRef()
+}
+
+// AddressEndpoint is an endpoint representing an address assigned to an
+// AddressableEndpoint.
+type AddressEndpoint interface {
+ AssignableAddressEndpoint
+
+ // GetKind returns the address kind for this endpoint.
+ GetKind() AddressKind
+
+ // SetKind sets the address kind for this endpoint.
+ SetKind(AddressKind)
+
+ // ConfigType returns the method used to add the address.
+ ConfigType() AddressConfigType
+
+ // Deprecated returns whether or not this endpoint is deprecated.
+ Deprecated() bool
+
+ // SetDeprecated sets this endpoint's deprecated status.
+ SetDeprecated(bool)
+}
+
+// AddressKind is the kind of of an address.
+//
+// See the values of AddressKind for more details.
+type AddressKind int
+
+const (
+ // PermanentTentative is a permanent address endpoint that is not yet
+ // considered to be fully bound to an interface in the traditional
+ // sense. That is, the address is associated with a NIC, but packets
+ // destined to the address MUST NOT be accepted and MUST be silently
+ // dropped, and the address MUST NOT be used as a source address for
+ // outgoing packets. For IPv6, addresses are of this kind until NDP's
+ // Duplicate Address Detection (DAD) resolves. If DAD fails, the address
+ // is removed.
+ PermanentTentative AddressKind = iota
+
+ // Permanent is a permanent endpoint (vs. a temporary one) assigned to the
+ // NIC. Its reference count is biased by 1 to avoid removal when no route
+ // holds a reference to it. It is removed by explicitly removing the address
+ // from the NIC.
+ Permanent
+
+ // PermanentExpired is a permanent endpoint that had its address removed from
+ // the NIC, and it is waiting to be removed once no references to it are held.
+ //
+ // If the address is re-added before the endpoint is removed, its type
+ // changes back to Permanent.
+ PermanentExpired
+
+ // Temporary is an endpoint, created on a one-off basis to temporarily
+ // consider the NIC bound an an address that it is not explictiy bound to
+ // (such as a permanent address). Its reference count must not be biased by 1
+ // so that the address is removed immediately when references to it are no
+ // longer held.
+ //
+ // A temporary endpoint may be promoted to permanent if the address is added
+ // permanently.
+ Temporary
+)
+
+// IsPermanent returns true if the AddressKind represents a permanent address.
+func (k AddressKind) IsPermanent() bool {
+ switch k {
+ case Permanent, PermanentTentative:
+ return true
+ case Temporary, PermanentExpired:
+ return false
+ default:
+ panic(fmt.Sprintf("unrecognized address kind = %d", k))
+ }
+}
+
+// AddressableEndpoint is an endpoint that supports addressing.
+//
+// An endpoint is considered to support addressing when the endpoint may
+// associate itself with an identifier (address).
+type AddressableEndpoint interface {
+ // AddAndAcquirePermanentAddress adds the passed permanent address.
+ //
+ // Returns tcpip.ErrDuplicateAddress if the address exists.
+ //
+ // Acquires and returns the AddressEndpoint for the added address.
+ AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error)
+
+ // RemovePermanentAddress removes the passed address if it is a permanent
+ // address.
+ //
+ // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed
+ // permanent address.
+ RemovePermanentAddress(addr tcpip.Address) *tcpip.Error
+
+ // MainAddress returns the endpoint's primary permanent address.
+ MainAddress() tcpip.AddressWithPrefix
+
+ // AcquireAssignedAddress returns an address endpoint for the passed address
+ // that is considered bound to the endpoint, optionally creating a temporary
+ // endpoint if requested and no existing address exists.
+ //
+ // The returned endpoint's reference count is incremented.
+ //
+ // Returns nil if the specified address is not local to this endpoint.
+ AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint
+
+ // AcquireOutgoingPrimaryAddress returns a primary address that may be used as
+ // a source address when sending packets to the passed remote address.
+ //
+ // If allowExpired is true, expired addresses may be returned.
+ //
+ // The returned endpoint's reference count is incremented.
+ //
+ // Returns nil if a primary address is not available.
+ AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint
+
+ // PrimaryAddresses returns the primary addresses.
+ PrimaryAddresses() []tcpip.AddressWithPrefix
+
+ // PermanentAddresses returns all the permanent addresses.
+ PermanentAddresses() []tcpip.AddressWithPrefix
+}
+
+// NDPEndpoint is a network endpoint that supports NDP.
+type NDPEndpoint interface {
+ NetworkEndpoint
+
+ // InvalidateDefaultRouter invalidates a default router discovered through
+ // NDP.
+ InvalidateDefaultRouter(tcpip.Address)
+}
+
+// NetworkInterface is a network interface.
+type NetworkInterface interface {
+ // ID returns the interface's ID.
+ ID() tcpip.NICID
+
+ // IsLoopback returns true if the interface is a loopback interface.
+ IsLoopback() bool
+
+ // Name returns the name of the interface.
+ //
+ // May return an empty string if the interface is not configured with a name.
+ Name() string
+
+ // Enabled returns true if the interface is enabled.
+ Enabled() bool
+
+ // LinkEndpoint returns the link endpoint backing the interface.
+ LinkEndpoint() LinkEndpoint
+}
+
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
+ AddressableEndpoint
+
+ // Enable enables the endpoint.
+ //
+ // Must only be called when the stack is in a state that allows the endpoint
+ // to send and receive packets.
+ //
+ // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled.
+ Enable() *tcpip.Error
+
+ // Enabled returns true if the endpoint is enabled.
+ Enabled() bool
+
+ // Disable disables the endpoint.
+ Disable()
+
// DefaultTTL is the default time-to-live value (or hop limit, in ipv6)
// for this endpoint.
DefaultTTL() uint8
@@ -234,10 +521,6 @@ type NetworkEndpoint interface {
// minus the network endpoint max header length.
MTU() uint32
- // Capabilities returns the set of capabilities supported by the
- // underlying link-layer endpoint.
- Capabilities() LinkEndpointCapabilities
-
// MaxHeaderLength returns the maximum size the network (and lower
// level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -245,8 +528,8 @@ type NetworkEndpoint interface {
MaxHeaderLength() uint16
// WritePacket writes a packet to the given destination address and
- // protocol. It takes ownership of pkt. pkt.TransportHeader must have already
- // been set.
+ // protocol. It takes ownership of pkt. pkt.TransportHeader must have
+ // already been set.
WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
@@ -258,15 +541,6 @@ type NetworkEndpoint interface {
// header to the given destination address. It takes ownership of pkt.
WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
- // ID returns the network protocol endpoint ID.
- ID() *NetworkEndpointID
-
- // PrefixLen returns the network endpoint's subnet prefix length in bits.
- PrefixLen() int
-
- // NICID returns the id of the NIC this endpoint belongs to.
- NICID() tcpip.NICID
-
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint. It sets pkt.NetworkHeader.
//
@@ -281,6 +555,17 @@ type NetworkEndpoint interface {
NetworkProtocolNumber() tcpip.NetworkProtocolNumber
}
+// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets.
+type ForwardingNetworkProtocol interface {
+ NetworkProtocol
+
+ // Forwarding returns the forwarding configuration.
+ Forwarding() bool
+
+ // SetForwarding sets the forwarding configuration.
+ SetForwarding(bool)
+}
+
// NetworkProtocol is the interface that needs to be implemented by network
// protocols (e.g., ipv4, ipv6) that want to be part of the networking stack.
type NetworkProtocol interface {
@@ -300,17 +585,17 @@ type NetworkProtocol interface {
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) (NetworkEndpoint, *tcpip.Error)
+ NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
- SetOption(option interface{}) *tcpip.Error
+ SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
- Option(option interface{}) *tcpip.Error
+ Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error
// Close requests that any worker goroutines owned by the protocol
// stop.
@@ -329,8 +614,7 @@ type NetworkProtocol interface {
}
// NetworkDispatcher contains the methods used by the network stack to deliver
-// packets to the appropriate network endpoint after it has been handled by
-// the data link layer.
+// inbound/outbound packets to the appropriate network/packet(if any) endpoints.
type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol endpoint
// and hands the packet over for further processing.
@@ -341,6 +625,16 @@ type NetworkDispatcher interface {
//
// DeliverNetworkPacket takes ownership of pkt.
DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+
+ // DeliverOutboundPacket is called by link layer when a packet is being
+ // sent out.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverOutboundPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverOutboundPacket takes ownership of pkt.
+ DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -420,8 +714,8 @@ type LinkEndpoint interface {
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
//
- // Attach will be called with a nil dispatcher if the receiver's associated
- // NIC is being removed.
+ // Attach is called with a nil dispatcher when the endpoint's NIC is being
+ // removed.
Attach(dispatcher NetworkDispatcher)
// IsAttached returns whether a NetworkDispatcher is attached to the
@@ -436,6 +730,15 @@ type LinkEndpoint interface {
// Wait will not block if the endpoint hasn't started any goroutines
// yet, even if it might later.
Wait()
+
+ // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint.
+ //
+ // See:
+ // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30
+ ARPHardwareType() header.ARPHardwareType
+
+ // AddHeader adds a link layer header to pkt if required.
+ AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -456,12 +759,13 @@ type InjectableLinkEndpoint interface {
// A LinkAddressResolver is an extension to a NetworkProtocol that
// can resolve link addresses.
type LinkAddressResolver interface {
- // LinkAddressRequest sends a request for the LinkAddress of addr.
- // The request is sent on linkEP with localAddr as the source.
+ // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
+ // the request on the local network if remoteLinkAddr is the zero value. The
+ // request is sent on linkEP with localAddr as the source.
//
// A valid response will cause the discovery protocol's network
// endpoint to call AddLinkAddress.
- LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+ LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
@@ -471,7 +775,7 @@ type LinkAddressResolver interface {
ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool)
// LinkAddressProtocol returns the network protocol of the
- // addresses this this resolver can resolve.
+ // addresses this resolver can resolve.
LinkAddressProtocol() tcpip.NetworkProtocolNumber
}
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index d65f8049e..cc39c9a6a 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -42,17 +42,27 @@ type Route struct {
// NetProto is the network-layer protocol.
NetProto tcpip.NetworkProtocolNumber
- // ref a reference to the network endpoint through which the route
- // starts.
- ref *referencedNetworkEndpoint
-
// Loop controls where WritePacket should send packets.
Loop PacketLooping
+
+ // nic is the NIC the route goes through.
+ nic *NIC
+
+ // addressEndpoint is the local address this route is associated with.
+ addressEndpoint AssignableAddressEndpoint
+
+ // linkCache is set if link address resolution is enabled for this protocol on
+ // the route's NIC.
+ linkCache LinkAddressCache
+
+ // linkRes is set if link address resolution is enabled for this protocol on
+ // the route's NIC.
+ linkRes LinkAddressResolver
}
// makeRoute initializes a new route. It takes ownership of the provided
-// reference to a network endpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, handleLocal, multicastLoop bool) Route {
+// AssignableAddressEndpoint.
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
loop := PacketOut
if handleLocal && localAddr != "" && remoteAddr == localAddr {
loop = PacketLoop
@@ -62,29 +72,40 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop |= PacketLoop
}
- return Route{
+ linkEP := nic.LinkEndpoint()
+ r := Route{
NetProto: netProto,
LocalAddress: localAddr,
- LocalLinkAddress: localLinkAddr,
+ LocalLinkAddress: linkEP.LinkAddress(),
RemoteAddress: remoteAddr,
- ref: ref,
+ addressEndpoint: addressEndpoint,
+ nic: nic,
Loop: loop,
}
+
+ if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ r.linkRes = linkRes
+ r.linkCache = nic.stack
+ }
+ }
+
+ return r
}
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
- return r.ref.ep.NICID()
+ return r.nic.ID()
}
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
- return r.ref.ep.MaxHeaderLength()
+ return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
func (r *Route) Stats() tcpip.Stats {
- return r.ref.nic.stack.Stats()
+ return r.nic.stack.Stats()
}
// PseudoHeaderChecksum forwards the call to the network endpoint's
@@ -95,17 +116,23 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
// Capabilities returns the link-layer capabilities of the route.
func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.ref.ep.Capabilities()
+ return r.nic.LinkEndpoint().Capabilities()
}
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
- if gso, ok := r.ref.ep.(GSOEndpoint); ok {
+ if gso, ok := r.nic.getNetworkEndpoint(r.NetProto).(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
}
+// ResolveWith immediately resolves a route with the specified remote link
+// address.
+func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
+ r.RemoteLinkAddress = addr
+}
+
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
// notified when address resolution is complete (success or not).
@@ -131,7 +158,17 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
}
nextAddr = r.RemoteAddress
}
- linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+
+ if neigh := r.nic.neigh; neigh != nil {
+ entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker)
+ if err != nil {
+ return ch, err
+ }
+ r.RemoteLinkAddress = entry.LinkAddr
+ return nil, nil
+ }
+
+ linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
if err != nil {
return ch, err
}
@@ -145,7 +182,13 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
if nextAddr == "" {
nextAddr = r.RemoteAddress
}
- r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
+
+ if neigh := r.nic.neigh; neigh != nil {
+ neigh.removeWaker(nextAddr, waker)
+ return
+ }
+
+ r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker)
}
// IsResolutionRequired returns true if Resolve() must be called to resolve
@@ -153,102 +196,88 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
//
// The NIC r uses must not be locked.
func (r *Route) IsResolutionRequired() bool {
- return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+ if r.nic.neigh != nil {
+ return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == ""
+ }
+ return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
- if !r.ref.isValidForOutgoing() {
+ if !r.nic.isValidForOutgoing(r.addressEndpoint) {
return tcpip.ErrInvalidEndpointState
}
// WritePacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Header.UsedLength() + pkt.Data.Size()
+ numBytes := pkt.Size()
- err := r.ref.ep.WritePacket(r, gso, params, pkt)
- if err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- } else {
- r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ if err := r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt); err != nil {
+ return err
}
- return err
+
+ r.nic.stats.Tx.Packets.Increment()
+ r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ return nil
}
// WritePackets writes a list of n packets through the given route and returns
// the number of packets written.
func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
- if !r.ref.isValidForOutgoing() {
+ if !r.nic.isValidForOutgoing(r.addressEndpoint) {
return 0, tcpip.ErrInvalidEndpointState
}
- // WritePackets takes ownership of pkt, calculate length first.
- numPkts := pkts.Len()
-
- n, err := r.ref.ep.WritePackets(r, gso, pkts, params)
- if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n))
- }
- r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
-
+ n, err := r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
+ r.nic.stats.Tx.Packets.IncrementBy(uint64(n))
writtenBytes := 0
for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
- writtenBytes += pb.Header.UsedLength()
- writtenBytes += pb.Data.Size()
+ writtenBytes += pb.Size()
}
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
+ r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
return n, err
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
- if !r.ref.isValidForOutgoing() {
+ if !r.nic.isValidForOutgoing(r.addressEndpoint) {
return tcpip.ErrInvalidEndpointState
}
// WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Data.Size()
- if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+ if err := r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt); err != nil {
return err
}
- r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ r.nic.stats.Tx.Packets.Increment()
+ r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
return nil
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
- return r.ref.ep.DefaultTTL()
+ return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
- return r.ref.ep.MTU()
-}
-
-// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying
-// network endpoint.
-func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return r.ref.ep.NetworkProtocolNumber()
+ return r.nic.getNetworkEndpoint(r.NetProto).MTU()
}
// Release frees all resources associated with the route.
func (r *Route) Release() {
- if r.ref != nil {
- r.ref.decRef()
- r.ref = nil
+ if r.addressEndpoint != nil {
+ r.addressEndpoint.DecRef()
+ r.addressEndpoint = nil
}
}
-// Clone Clone a route such that the original one can be released and the new
-// one will remain valid.
+// Clone clones the route.
func (r *Route) Clone() Route {
- if r.ref != nil {
- r.ref.incRef()
+ if r.addressEndpoint != nil {
+ _ = r.addressEndpoint.IncRef()
}
return *r
}
@@ -272,7 +301,30 @@ func (r *Route) MakeLoopedRoute() Route {
// Stack returns the instance of the Stack that owns this route.
func (r *Route) Stack() *Stack {
- return r.ref.stack()
+ return r.nic.stack
+}
+
+func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
+ if addr == header.IPv4Broadcast {
+ return true
+ }
+
+ subnet := r.addressEndpoint.AddressWithPrefix().Subnet()
+ return subnet.IsBroadcast(addr)
+}
+
+// IsOutboundBroadcast returns true if the route is for an outbound broadcast
+// packet.
+func (r *Route) IsOutboundBroadcast() bool {
+ // Only IPv4 has a notion of broadcast.
+ return r.isV4Broadcast(r.RemoteAddress)
+}
+
+// IsInboundBroadcast returns true if the route is for an inbound broadcast
+// packet.
+func (r *Route) IsInboundBroadcast() bool {
+ // Only IPv4 has a notion of broadcast.
+ return r.isV4Broadcast(r.LocalAddress)
}
// ReverseRoute returns new route with given source and destination address.
@@ -283,7 +335,10 @@ func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
LocalLinkAddress: r.RemoteLinkAddress,
RemoteAddress: src,
RemoteLinkAddress: r.LocalLinkAddress,
- ref: r.ref,
Loop: r.Loop,
+ addressEndpoint: r.addressEndpoint,
+ nic: r.nic,
+ linkCache: r.linkCache,
+ linkRes: r.linkRes,
}
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index cdcfb8321..0bf20c0e1 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -73,6 +73,16 @@ type TCPCubicState struct {
WEst float64
}
+// TCPRACKState is used to hold a copy of the internal RACK state when the
+// TCPProbeFunc is invoked.
+type TCPRACKState struct {
+ XmitTime time.Time
+ EndSequence seqnum.Value
+ FACK seqnum.Value
+ RTT time.Duration
+ Reord bool
+}
+
// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
type TCPEndpointID struct {
// LocalPort is the local port associated with the endpoint.
@@ -134,10 +144,7 @@ type TCPReceiverState struct {
// PendingBufUsed is the number of bytes pending in the receive
// queue.
- PendingBufUsed seqnum.Size
-
- // PendingBufSize is the size of the socket receive buffer.
- PendingBufSize seqnum.Size
+ PendingBufUsed int
}
// TCPSenderState holds a copy of the internal state of the sender for
@@ -212,6 +219,9 @@ type TCPSenderState struct {
// Cubic holds the state related to CUBIC congestion control.
Cubic TCPCubicState
+
+ // RACKState holds the state related to RACK loss detection algorithm.
+ RACKState TCPRACKState
}
// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
@@ -235,7 +245,7 @@ type RcvBufAutoTuneParams struct {
// was started.
MeasureTime time.Time
- // CopiedBytes is the number of bytes copied to userspace since
+ // CopiedBytes is the number of bytes copied to user space since
// this measure began.
CopiedBytes int
@@ -353,38 +363,6 @@ func (u *uniqueIDGenerator) UniqueID() uint64 {
return atomic.AddUint64((*uint64)(u), 1)
}
-// NICNameFromID is a function that returns a stable name for the specified NIC,
-// even if different NIC IDs are used to refer to the same NIC in different
-// program runs. It is used when generating opaque interface identifiers (IIDs).
-// If the NIC was created with a name, it will be passed to NICNameFromID.
-//
-// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
-// generated for the same prefix on differnt NICs.
-type NICNameFromID func(tcpip.NICID, string) string
-
-// OpaqueInterfaceIdentifierOptions holds the options related to the generation
-// of opaque interface indentifiers (IIDs) as defined by RFC 7217.
-type OpaqueInterfaceIdentifierOptions struct {
- // NICNameFromID is a function that returns a stable name for a specified NIC,
- // even if the NIC ID changes over time.
- //
- // Must be specified to generate the opaque IID.
- NICNameFromID NICNameFromID
-
- // SecretKey is a pseudo-random number used as the secret key when generating
- // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
- // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
- // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
- // change between program runs, unless explicitly changed.
- //
- // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
- // MUST NOT be modified after Stack is created.
- //
- // May be nil, but a nil value is highly discouraged to maintain
- // some level of randomness between nodes.
- SecretKey []byte
-}
-
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -402,10 +380,12 @@ type Stack struct {
linkAddrCache *linkAddrCache
- mu sync.RWMutex
- nics map[tcpip.NICID]*NIC
- forwarding bool
- cleanupEndpoints map[TransportEndpoint]struct{}
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
+
+ // cleanupEndpointsMu protects cleanupEndpoints.
+ cleanupEndpointsMu sync.Mutex
+ cleanupEndpoints map[TransportEndpoint]struct{}
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
@@ -416,7 +396,7 @@ type Stack struct {
// If not nil, then any new endpoints will have this probe function
// invoked everytime they receive a TCP segment.
- tcpProbeFunc TCPProbeFunc
+ tcpProbeFunc atomic.Value // TCPProbeFunc
// clock is used to generate user-visible times.
clock tcpip.Clock
@@ -425,6 +405,7 @@ type Stack struct {
handleLocal bool
// tables are the iptables packet filtering and manipulation rules.
+ // TODO(gvisor.dev/issue/170): S/R this field.
tables *IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
@@ -441,29 +422,20 @@ type Stack struct {
// TODO(gvisor.dev/issue/940): S/R this field.
seed uint32
- // ndpConfigs is the default NDP configurations used by interfaces.
- ndpConfigs NDPConfigurations
+ // nudConfigs is the default NUD configurations used by interfaces.
+ nudConfigs NUDConfigurations
- // autoGenIPv6LinkLocal determines whether or not the stack will attempt
- // to auto-generate an IPv6 link-local address for newly enabled non-loopback
- // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
- autoGenIPv6LinkLocal bool
+ // useNeighborCache indicates whether ARP and NDP packets should be handled
+ // by the NIC's neighborCache instead of linkAddrCache.
+ useNeighborCache bool
- // ndpDisp is the NDP event dispatcher that is used to send the netstack
- // integrator NDP related events.
- ndpDisp NDPDispatcher
+ // nudDisp is the NUD event dispatcher that is used to send the netstack
+ // integrator NUD related events.
+ nudDisp NUDDispatcher
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
- // opaqueIIDOpts hold the options for generating opaque interface identifiers
- // (IIDs) as outlined by RFC 7217.
- opaqueIIDOpts OpaqueInterfaceIdentifierOptions
-
- // tempIIDSeed is used to seed the initial temporary interface identifier
- // history value used to generate IIDs for temporary SLAAC addresses.
- tempIIDSeed []byte
-
// forwarder holds the packets that wait for their link-address resolutions
// to complete, and forwards them when each resolution is done.
forwarder *forwardQueue
@@ -486,13 +458,25 @@ type UniqueID interface {
UniqueID() uint64
}
+// NetworkProtocolFactory instantiates a network protocol.
+//
+// NetworkProtocolFactory must not attempt to modify the stack, it may only
+// query the stack.
+type NetworkProtocolFactory func(*Stack) NetworkProtocol
+
+// TransportProtocolFactory instantiates a transport protocol.
+//
+// TransportProtocolFactory must not attempt to modify the stack, it may only
+// query the stack.
+type TransportProtocolFactory func(*Stack) TransportProtocol
+
// Options contains optional Stack configuration.
type Options struct {
// NetworkProtocols lists the network protocols to enable.
- NetworkProtocols []NetworkProtocol
+ NetworkProtocols []NetworkProtocolFactory
// TransportProtocols lists the transport protocols to enable.
- TransportProtocols []TransportProtocol
+ TransportProtocols []TransportProtocolFactory
// Clock is an optional clock source used for timestampping packets.
//
@@ -510,60 +494,30 @@ type Options struct {
// UniqueID is an optional generator of unique identifiers.
UniqueID UniqueID
- // NDPConfigs is the default NDP configurations used by interfaces.
- //
- // By default, NDPConfigs will have a zero value for its
- // DupAddrDetectTransmits field, implying that DAD will not be performed
- // before assigning an address to a NIC.
- NDPConfigs NDPConfigurations
-
- // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to
- // auto-generate an IPv6 link-local address for newly enabled non-loopback
- // NICs.
- //
- // Note, setting this to true does not mean that a link-local address
- // will be assigned right away, or at all. If Duplicate Address Detection
- // is enabled, an address will only be assigned if it successfully resolves.
- // If it fails, no further attempt will be made to auto-generate an IPv6
- // link-local address.
- //
- // The generated link-local address will follow RFC 4291 Appendix A
- // guidelines.
- AutoGenIPv6LinkLocal bool
+ // NUDConfigs is the default NUD configurations used by interfaces.
+ NUDConfigs NUDConfigurations
+
+ // UseNeighborCache indicates whether ARP and NDP packets should be handled
+ // by the Neighbor Unreachability Detection (NUD) state machine. This flag
+ // also enables the APIs for inspecting and modifying the neighbor table via
+ // NUDDispatcher and the following Stack methods: Neighbors, RemoveNeighbor,
+ // and ClearNeighbors.
+ UseNeighborCache bool
- // NDPDisp is the NDP event dispatcher that an integrator can provide to
- // receive NDP related events.
- NDPDisp NDPDispatcher
+ // NUDDisp is the NUD event dispatcher that an integrator can provide to
+ // receive NUD related events.
+ NUDDisp NUDDispatcher
// RawFactory produces raw endpoints. Raw endpoints are enabled only if
// this is non-nil.
RawFactory RawFactory
- // OpaqueIIDOpts hold the options for generating opaque interface
- // identifiers (IIDs) as outlined by RFC 7217.
- OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
-
// RandSource is an optional source to use to generate random
// numbers. If omitted it defaults to a Source seeded by the data
// returned by rand.Read().
//
// RandSource must be thread-safe.
RandSource mathrand.Source
-
- // TempIIDSeed is used to seed the initial temporary interface identifier
- // history value used to generate IIDs for temporary SLAAC addresses.
- //
- // Temporary SLAAC adresses are short-lived addresses which are unpredictable
- // and random from the perspective of other nodes on the network. It is
- // recommended that the seed be a random byte buffer of at least
- // header.IIDSize bytes to make sure that temporary SLAAC addresses are
- // sufficiently random. It should follow minimum randomness requirements for
- // security as outlined by RFC 4086.
- //
- // Note: using a nil value, the same seed across netstack program runs, or a
- // seed that is too small would reduce randomness and increase predictability,
- // defeating the purpose of temporary SLAAC addresses.
- TempIIDSeed []byte
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -666,31 +620,28 @@ func New(opts Options) *Stack {
randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
}
- // Make sure opts.NDPConfigs contains valid values only.
- opts.NDPConfigs.validate()
+ opts.NUDConfigs.resetInvalidFields()
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
- nics: make(map[tcpip.NICID]*NIC),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: DefaultTables(),
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: generateRandUint32(),
- ndpConfigs: opts.NDPConfigs,
- autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
- uniqueIDGenerator: opts.UniqueID,
- ndpDisp: opts.NDPDisp,
- opaqueIIDOpts: opts.OpaqueIIDOpts,
- tempIIDSeed: opts.TempIIDSeed,
- forwarder: newForwardQueue(),
- randomGenerator: mathrand.New(randSrc),
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ nics: make(map[tcpip.NICID]*NIC),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: DefaultTables(),
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: generateRandUint32(),
+ nudConfigs: opts.NUDConfigs,
+ useNeighborCache: opts.UseNeighborCache,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ forwarder: newForwardQueue(),
+ randomGenerator: mathrand.New(randSrc),
sendBufferSize: SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -704,7 +655,8 @@ func New(opts Options) *Stack {
}
// Add specified network protocols.
- for _, netProto := range opts.NetworkProtocols {
+ for _, netProtoFactory := range opts.NetworkProtocols {
+ netProto := netProtoFactory(s)
s.networkProtocols[netProto.Number()] = netProto
if r, ok := netProto.(LinkAddressResolver); ok {
s.linkAddrResolvers[r.LinkAddressProtocol()] = r
@@ -712,7 +664,8 @@ func New(opts Options) *Stack {
}
// Add specified transport protocols.
- for _, transProto := range opts.TransportProtocols {
+ for _, transProtoFactory := range opts.TransportProtocols {
+ transProto := transProtoFactory(s)
s.transportProtocols[transProto.Number()] = &transportProtocolState{
proto: transProto,
}
@@ -727,6 +680,11 @@ func New(opts Options) *Stack {
return s
}
+// newJob returns a tcpip.Job using the Stack clock.
+func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job {
+ return tcpip.NewJob(s.clock, l, f)
+}
+
// UniqueID returns a unique identifier.
func (s *Stack) UniqueID() uint64 {
return s.uniqueIDGenerator.UniqueID()
@@ -736,7 +694,7 @@ func (s *Stack) UniqueID() uint64 {
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
// is incorrect.
-func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
netProto, ok := s.networkProtocols[network]
if !ok {
return tcpip.ErrUnknownProtocol
@@ -753,7 +711,7 @@ func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, op
// if err != nil {
// ...
// }
-func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
netProto, ok := s.networkProtocols[network]
if !ok {
return tcpip.ErrUnknownProtocol
@@ -765,7 +723,7 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
// is incorrect.
-func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error {
transProtoState, ok := s.transportProtocols[transport]
if !ok {
return tcpip.ErrUnknownProtocol
@@ -780,7 +738,7 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb
// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
// ...
// }
-func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error {
transProtoState, ok := s.transportProtocols[transport]
if !ok {
return tcpip.ErrUnknownProtocol
@@ -800,9 +758,10 @@ func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h f
}
}
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (s *Stack) NowNanoseconds() int64 {
- return s.clock.NowNanoseconds()
+// Clock returns the Stack's clock for retrieving the current time and
+// scheduling work.
+func (s *Stack) Clock() tcpip.Clock {
+ return s.clock
}
// Stats returns a mutable copy of the current stats.
@@ -813,46 +772,37 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwarding enables or disables the packet forwarding between NICs.
-//
-// When forwarding becomes enabled, any host-only state on all NICs will be
-// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started.
-// When forwarding becomes disabled and if IPv6 is enabled, NDP Router
-// Solicitations will be stopped.
-func (s *Stack) SetForwarding(enable bool) {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.Lock()
- defer s.mu.Unlock()
+// SetForwarding enables or disables packet forwarding between NICs for the
+// passed protocol.
+func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error {
+ protocol, ok := s.networkProtocols[protocolNum]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
- // If forwarding status didn't change, do nothing further.
- if s.forwarding == enable {
- return
+ forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
+ if !ok {
+ return tcpip.ErrNotSupported
}
- s.forwarding = enable
+ forwardingProtocol.SetForwarding(enable)
+ return nil
+}
- // If this stack does not support IPv6, do nothing further.
- if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok {
- return
+// Forwarding returns true if packet forwarding between NICs is enabled for the
+// passed protocol.
+func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool {
+ protocol, ok := s.networkProtocols[protocolNum]
+ if !ok {
+ return false
}
- if enable {
- for _, nic := range s.nics {
- nic.becomeIPv6Router()
- }
- } else {
- for _, nic := range s.nics {
- nic.becomeIPv6Host()
- }
+ forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
+ if !ok {
+ return false
}
-}
-// Forwarding returns if the packet forwarding between NICs is enabled.
-func (s *Stack) Forwarding() bool {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.forwarding
+ return forwardingProtocol.Forwarding()
}
// SetRouteTable assigns the route table to be used by this stack. It
@@ -887,7 +837,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
return nil, tcpip.ErrUnknownProtocol
}
- return t.proto.NewEndpoint(s, network, waiterQueue)
+ return t.proto.NewEndpoint(network, waiterQueue)
}
// NewRawEndpoint creates a new raw transport layer endpoint of the given
@@ -907,7 +857,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
return nil, tcpip.ErrUnknownProtocol
}
- return t.proto.NewRawEndpoint(s, network, waiterQueue)
+ return t.proto.NewRawEndpoint(network, waiterQueue)
}
// NewPacketEndpoint creates a new packet endpoint listening for the given
@@ -1014,7 +964,8 @@ func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error {
return tcpip.ErrUnknownNICID
}
- return nic.disable()
+ nic.disable()
+ return nil
}
// CheckNIC checks if a NIC is usable.
@@ -1027,7 +978,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool {
return false
}
- return nic.enabled()
+ return nic.Enabled()
}
// RemoveNIC removes NIC and all related routes from the network stack.
@@ -1064,19 +1015,6 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error {
return nic.remove()
}
-// NICAddressRanges returns a map of NICIDs to their associated subnets.
-func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nics := map[tcpip.NICID][]tcpip.Subnet{}
-
- for id, nic := range s.nics {
- nics[id] = append(nics[id], nic.AddressRanges()...)
- }
- return nics
-}
-
// NICInfo captures the name and addresses assigned to a NIC.
type NICInfo struct {
Name string
@@ -1094,6 +1032,11 @@ type NICInfo struct {
// Context is user-supplied data optionally supplied in CreateNICWithOptions.
// See type NICOptions for more details.
Context NICContext
+
+ // ARPHardwareType holds the ARP Hardware type of the NIC. This is the
+ // value sent in haType field of an ARP Request sent by this NIC and the
+ // value expected in the haType field of an ARP response.
+ ARPHardwareType header.ARPHardwareType
}
// HasNIC returns true if the NICID is defined in the stack.
@@ -1113,18 +1056,19 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
for id, nic := range s.nics {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
- Running: nic.enabled(),
+ Running: nic.Enabled(),
Promiscuous: nic.isPromiscuousMode(),
- Loopback: nic.isLoopback(),
+ Loopback: nic.IsLoopback(),
}
nics[id] = NICInfo{
Name: nic.name,
LinkAddress: nic.linkEP.LinkAddress(),
- ProtocolAddresses: nic.PrimaryAddresses(),
+ ProtocolAddresses: nic.primaryAddresses(),
Flags: flags,
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
Context: nic.context,
+ ARPHardwareType: nic.linkEP.ARPHardwareType(),
}
}
return nics
@@ -1178,41 +1122,12 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
+ nic, ok := s.nics[id]
+ if !ok {
return tcpip.ErrUnknownNICID
}
- return nic.AddAddress(protocolAddress, peb)
-}
-
-// AddAddressRange adds a range of addresses to the specified NIC. The range is
-// given by a subnet address, and all addresses contained in the subnet are
-// used except for the subnet address itself and the subnet's broadcast
-// address.
-func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- if nic, ok := s.nics[id]; ok {
- nic.AddAddressRange(protocol, subnet)
- return nil
- }
-
- return tcpip.ErrUnknownNICID
-}
-
-// RemoveAddressRange removes the range of addresses from the specified NIC.
-func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- if nic, ok := s.nics[id]; ok {
- nic.RemoveAddressRange(subnet)
- return nil
- }
-
- return tcpip.ErrUnknownNICID
+ return nic.addAddress(protocolAddress, peb)
}
// RemoveAddress removes an existing network-layer address from the specified
@@ -1222,7 +1137,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
- return nic.RemoveAddress(addr)
+ return nic.removeAddress(addr)
}
return tcpip.ErrUnknownNICID
@@ -1236,7 +1151,7 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress)
for id, nic := range s.nics {
- nics[id] = nic.AllAddresses()
+ nics[id] = nic.allPermanentAddresses()
}
return nics
}
@@ -1258,7 +1173,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol
return nic.primaryAddress(protocol), nil
}
-func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
+func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint {
if len(localAddr) == 0 {
return nic.primaryEndpoint(netProto, remoteAddr)
}
@@ -1271,13 +1186,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
s.mu.RLock()
defer s.mu.RUnlock()
- isBroadcast := remoteAddr == header.IPv4Broadcast
+ isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
- needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
+ needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
- if nic, ok := s.nics[id]; ok && nic.enabled() {
- if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
- return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
+ if nic, ok := s.nics[id]; ok && nic.Enabled() {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil
}
}
} else {
@@ -1285,18 +1200,23 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
continue
}
- if nic, ok := s.nics[route.NIC]; ok && nic.enabled() {
- if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
+ if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
// provided will refer to the link local address.
- remoteAddr = ref.ep.ID().LocalAddress
+ remoteAddr = addressEndpoint.AddressWithPrefix().Address
}
- r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
- if needRoute {
- r.NextHop = route.Gateway
+ r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback())
+ if len(route.Gateway) > 0 {
+ if needRoute {
+ r.NextHop = route.Gateway
+ }
+ } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
+
return r, nil
}
}
@@ -1326,26 +1246,25 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto
// If a NIC is specified, we try to find the address there only.
if nicID != 0 {
- nic := s.nics[nicID]
- if nic == nil {
+ nic, ok := s.nics[nicID]
+ if !ok {
return 0
}
- ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
- if ref == nil {
+ addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
return 0
}
- ref.decRef()
+ addressEndpoint.DecRef()
return nic.id
}
// Go through all the NICs.
for _, nic := range s.nics {
- ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
- if ref != nil {
- ref.decRef()
+ if addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
return nic.id
}
}
@@ -1358,8 +1277,8 @@ func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[nicID]
- if nic == nil {
+ nic, ok := s.nics[nicID]
+ if !ok {
return tcpip.ErrUnknownNICID
}
@@ -1374,8 +1293,8 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[nicID]
- if nic == nil {
+ nic, ok := s.nics[nicID]
+ if !ok {
return tcpip.ErrUnknownNICID
}
@@ -1407,8 +1326,33 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
}
-// RemoveWaker implements LinkAddressCache.RemoveWaker.
+// Neighbors returns all IP to MAC address associations.
+func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return nil, tcpip.ErrUnknownNICID
+ }
+
+ return nic.neighbors()
+}
+
+// RemoveWaker removes a waker that has been added when link resolution for
+// addr was requested.
func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
+ if s.useNeighborCache {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if ok {
+ nic.removeWaker(addr, waker)
+ }
+ return
+ }
+
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1418,6 +1362,47 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.
}
}
+// AddStaticNeighbor statically associates an IP address to a MAC address.
+func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.addStaticNeighbor(addr, linkAddr)
+}
+
+// RemoveNeighbor removes an IP to MAC address association previously created
+// either automically or by AddStaticNeighbor. Returns ErrBadAddress if there
+// is no association with the provided address.
+func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.removeNeighbor(addr)
+}
+
+// ClearNeighbors removes all IP to MAC address associations.
+func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.clearNeighbors()
+}
+
// RegisterTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
@@ -1441,10 +1426,9 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
// StartTransportEndpointCleanup removes the endpoint with the given id from
// the stack transport dispatcher. It also transitions it to the cleanup stage.
func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
+ s.cleanupEndpointsMu.Lock()
s.cleanupEndpoints[ep] = struct{}{}
+ s.cleanupEndpointsMu.Unlock()
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
@@ -1452,9 +1436,9 @@ func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcp
// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
// stage.
func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
- s.mu.Lock()
+ s.cleanupEndpointsMu.Lock()
delete(s.cleanupEndpoints, ep)
- s.mu.Unlock()
+ s.cleanupEndpointsMu.Unlock()
}
// FindTransportEndpoint finds an endpoint that most closely matches the provided
@@ -1497,23 +1481,23 @@ func (s *Stack) RegisteredEndpoints() []TransportEndpoint {
// CleanupEndpoints returns endpoints currently in the cleanup state.
func (s *Stack) CleanupEndpoints() []TransportEndpoint {
- s.mu.Lock()
+ s.cleanupEndpointsMu.Lock()
es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints))
for e := range s.cleanupEndpoints {
es = append(es, e)
}
- s.mu.Unlock()
+ s.cleanupEndpointsMu.Unlock()
return es
}
// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
// for restoring a stack after a save.
func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
- s.mu.Lock()
+ s.cleanupEndpointsMu.Lock()
for _, e := range es {
s.cleanupEndpoints[e] = struct{}{}
}
- s.mu.Unlock()
+ s.cleanupEndpointsMu.Unlock()
}
// Close closes all currently registered transport endpoints.
@@ -1708,18 +1692,17 @@ func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) Tra
// guarantee provided on which probe will be invoked. Ideally this should only
// be called once per stack.
func (s *Stack) AddTCPProbe(probe TCPProbeFunc) {
- s.mu.Lock()
- s.tcpProbeFunc = probe
- s.mu.Unlock()
+ s.tcpProbeFunc.Store(probe)
}
// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil
// otherwise.
func (s *Stack) GetTCPProbe() TCPProbeFunc {
- s.mu.Lock()
- p := s.tcpProbeFunc
- s.mu.Unlock()
- return p
+ p := s.tcpProbeFunc.Load()
+ if p == nil {
+ return nil
+ }
+ return p.(TCPProbeFunc)
}
// RemoveTCPProbe removes an installed TCP probe.
@@ -1728,9 +1711,8 @@ func (s *Stack) GetTCPProbe() TCPProbeFunc {
// have a probe attached. Endpoints already created will continue to invoke
// TCP probe.
func (s *Stack) RemoveTCPProbe() {
- s.mu.Lock()
- s.tcpProbeFunc = nil
- s.mu.Unlock()
+ // This must be TCPProbeFunc(nil) because atomic.Value.Store(nil) panics.
+ s.tcpProbeFunc.Store(TCPProbeFunc(nil))
}
// JoinGroup joins the given multicast group on the given NIC.
@@ -1751,7 +1733,7 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC
defer s.mu.RUnlock()
if nic, ok := s.nics[nicID]; ok {
- return nic.leaveGroup(multicastAddr)
+ return nic.leaveGroup(protocol, multicastAddr)
}
return tcpip.ErrUnknownNICID
}
@@ -1803,70 +1785,47 @@ func (s *Stack) AllowICMPMessage() bool {
return s.icmpRateLimiter.Allow()
}
-// IsAddrTentative returns true if addr is tentative on the NIC with ID id.
-//
-// Note that if addr is not associated with a NIC with id ID, then this
-// function will return false. It will only return true if the address is
-// associated with the NIC AND it is tentative.
-func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) {
- s.mu.RLock()
- defer s.mu.RUnlock()
+// GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol
+// number installed on the specified NIC.
+func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
- nic, ok := s.nics[id]
+ nic, ok := s.nics[nicID]
if !ok {
- return false, tcpip.ErrUnknownNICID
+ return nil, tcpip.ErrUnknownNICID
}
- return nic.isAddrTentative(addr), nil
+ return nic.getNetworkEndpoint(proto), nil
}
-// DupTentativeAddrDetected attempts to inform the NIC with ID id that a
-// tentative addr on it is a duplicate on a link.
-func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
- s.mu.Lock()
- defer s.mu.Unlock()
-
+// NUDConfigurations gets the per-interface NUD configurations.
+func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) {
+ s.mu.RLock()
nic, ok := s.nics[id]
+ s.mu.RUnlock()
+
if !ok {
- return tcpip.ErrUnknownNICID
+ return NUDConfigurations{}, tcpip.ErrUnknownNICID
}
- return nic.dupTentativeAddrDetected(addr)
+ return nic.nudConfigs()
}
-// SetNDPConfigurations sets the per-interface NDP configurations on the NIC
-// with ID id to c.
+// SetNUDConfigurations sets the per-interface NUD configurations.
//
-// Note, if c contains invalid NDP configuration values, it will be fixed to
+// Note, if c contains invalid NUD configuration values, it will be fixed to
// use default values for the erroneous values.
-func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error {
- s.mu.Lock()
- defer s.mu.Unlock()
-
+func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error {
+ s.mu.RLock()
nic, ok := s.nics[id]
- if !ok {
- return tcpip.ErrUnknownNICID
- }
-
- nic.setNDPConfigs(c)
-
- return nil
-}
-
-// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement
-// message that it needs to handle.
-func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error {
- s.mu.Lock()
- defer s.mu.Unlock()
+ s.mu.RUnlock()
- nic, ok := s.nics[id]
if !ok {
return tcpip.ErrUnknownNICID
}
- nic.handleNDPRA(ip, ra)
-
- return nil
+ return nic.setNUDConfigs(c)
}
// Seed returns a 32 bit value that can be used as a seed value for port
@@ -1906,28 +1865,24 @@ func generateRandInt64() int64 {
// FindNetworkEndpoint returns the network endpoint for the given address.
func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) {
- s.mu.Lock()
- defer s.mu.Unlock()
+ s.mu.RLock()
+ defer s.mu.RUnlock()
for _, nic := range s.nics {
- id := NetworkEndpointID{address}
-
- if ref, ok := nic.mu.endpoints[id]; ok {
- nic.mu.RLock()
- defer nic.mu.RUnlock()
-
- // An endpoint with this id exists, check if it can be
- // used and return it.
- return ref.ep, nil
+ addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint)
+ if addressEndpoint == nil {
+ continue
}
+ addressEndpoint.DecRef()
+ return nic.getNetworkEndpoint(netProto), nil
}
return nil, tcpip.ErrBadAddress
}
-// FindNICNameFromID returns the name of the nic for the given NICID.
+// FindNICNameFromID returns the name of the NIC for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
- s.mu.Lock()
- defer s.mu.Unlock()
+ s.mu.RLock()
+ defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
@@ -1936,3 +1891,8 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
return nic.Name()
}
+
+// NewJob returns a new tcpip.Job using the stack's clock.
+func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job {
+ return tcpip.NewJob(s.clock, l, f)
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 7657a4101..aa20f750b 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,18 +21,21 @@ import (
"bytes"
"fmt"
"math"
+ "net"
"sort"
- "strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -66,40 +69,53 @@ const (
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fakeNetworkEndpoint struct {
+ stack.AddressableEndpointState
+
+ mu struct {
+ sync.RWMutex
+
+ enabled bool
+ }
+
nicID tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
ep stack.LinkEndpoint
}
-func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+func (f *fakeNetworkEndpoint) Enable() *tcpip.Error {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.enabled = true
+ return nil
}
-func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicID
+func (f *fakeNetworkEndpoint) Enabled() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.enabled
}
-func (f *fakeNetworkEndpoint) PrefixLen() int {
- return f.prefixLen
+func (f *fakeNetworkEndpoint) Disable() {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.enabled = false
}
-func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
- return 123
+func (f *fakeNetworkEndpoint) MTU() uint32 {
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
-func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
- return &f.id
+func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
+ return 123
}
func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
- f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
+ f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++
// Handle control packets.
- if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) {
+ if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
@@ -115,7 +131,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -126,10 +142,6 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto
return 0
}
-func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return f.ep.Capabilities()
-}
-
func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return f.proto.Number()
}
@@ -140,10 +152,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
// Add the protocol's header to the packet and send it to the link
// endpoint.
- pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen)
- pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0]
- pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0]
- pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol)
+ hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen)
+ hdr[dstAddrOffset] = r.RemoteAddress[0]
+ hdr[srcAddrOffset] = r.LocalAddress[0]
+ hdr[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
f.HandlePacket(r, pkt)
@@ -164,16 +176,8 @@ func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack
return tcpip.ErrNotSupported
}
-func (*fakeNetworkEndpoint) Close() {}
-
-type fakeNetGoodOption bool
-
-type fakeNetBadOption bool
-
-type fakeNetInvalidValueOption int
-
-type fakeNetOptions struct {
- good bool
+func (f *fakeNetworkEndpoint) Close() {
+ f.AddressableEndpointState.Cleanup()
}
// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
@@ -182,7 +186,12 @@ type fakeNetOptions struct {
type fakeNetworkProtocol struct {
packetCount [10]int
sendPacketCount [10]int
- opts fakeNetOptions
+ defaultTTL uint8
+
+ mu struct {
+ sync.RWMutex
+ forwarding bool
+ }
}
func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -205,57 +214,67 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
- return &fakeNetworkEndpoint{
- nicID: nicID,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
+func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &fakeNetworkEndpoint{
+ nicID: nic.ID(),
proto: f,
dispatcher: dispatcher,
- ep: ep,
- }, nil
+ ep: nic.LinkEndpoint(),
+ }
+ e.AddressableEndpointState.Init(e)
+ return e
}
-func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case fakeNetGoodOption:
- f.opts.good = bool(v)
+ case *tcpip.DefaultTTLOption:
+ f.defaultTTL = uint8(*v)
return nil
- case fakeNetInvalidValueOption:
- return tcpip.ErrInvalidOptionValue
default:
return tcpip.ErrUnknownProtocolOption
}
}
-func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
+func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case *fakeNetGoodOption:
- *v = fakeNetGoodOption(f.opts.good)
+ case *tcpip.DefaultTTLOption:
+ *v = tcpip.DefaultTTLOption(f.defaultTTL)
return nil
default:
return tcpip.ErrUnknownProtocolOption
}
}
-// Close implements TransportProtocol.Close.
+// Close implements NetworkProtocol.Close.
func (*fakeNetworkProtocol) Close() {}
-// Wait implements TransportProtocol.Wait.
+// Wait implements NetworkProtocol.Wait.
func (*fakeNetworkProtocol) Wait() {}
-// Parse implements TransportProtocol.Parse.
+// Parse implements NetworkProtocol.Parse.
func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
- hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen)
if !ok {
return 0, false, false
}
- pkt.NetworkHeader = hdr
- pkt.Data.TrimFront(fakeNetHeaderLen)
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
-func fakeNetFactory() stack.NetworkProtocol {
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (f *fakeNetworkProtocol) Forwarding() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (f *fakeNetworkProtocol) SetForwarding(v bool) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.forwarding = v
+}
+
+func fakeNetFactory(*stack.Stack) stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
@@ -276,12 +295,23 @@ func (l *linkEPWithMockedAttach) isAttached() bool {
return l.attached
}
+// Checks to see if list contains an address.
+func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool {
+ for _, i := range list {
+ if i == item {
+ return true
+ }
+ }
+
+ return false
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
ep := channel.New(10, defaultMTU, "")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -301,9 +331,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet with wrong address is not delivered.
buf[dstAddrOffset] = 3
- ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
@@ -313,9 +343,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to first endpoint.
buf[dstAddrOffset] = 1
- ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -325,9 +355,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to second endpoint.
buf[dstAddrOffset] = 2
- ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -336,9 +366,9 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -348,9 +378,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -369,11 +399,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
}
func send(r stack.Route, payload buffer.View) *tcpip.Error {
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
- Header: hdr,
- Data: payload.ToVectorisedView(),
- })
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: payload.ToVectorisedView(),
+ }))
}
func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
@@ -428,9 +457,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b
func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
t.Helper()
- ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if got := fakeNet.PacketCount(localAddrByte); got != want {
t.Errorf("receive packet count: got = %d, want %d", got, want)
}
@@ -442,7 +471,7 @@ func TestNetworkSend(t *testing.T) {
// existing nic.
ep := channel.New(10, defaultMTU, "")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("NewNIC failed:", err)
@@ -469,7 +498,7 @@ func TestNetworkSendMultiRoute(t *testing.T) {
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep1 := channel.New(10, defaultMTU, "")
@@ -569,7 +598,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
e := linkEPWithMockedAttach{
@@ -588,7 +617,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) {
func TestDisableUnknownNIC(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID {
@@ -600,7 +629,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
e := loopback.New()
@@ -647,7 +676,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
func TestRemoveUnknownNIC(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID {
@@ -659,7 +688,7 @@ func TestRemoveNIC(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
e := linkEPWithMockedAttach{
@@ -720,7 +749,7 @@ func TestRouteWithDownNIC(t *testing.T) {
setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep1 := channel.New(1, defaultMTU, "")
@@ -867,9 +896,9 @@ func TestRouteWithDownNIC(t *testing.T) {
// Writes with Routes that use NIC1 after being brought up should
// succeed.
//
- // TODO(b/147015577): Should we instead completely invalidate all
- // Routes that were bound to a NIC that was brought down at some
- // point?
+ // TODO(gvisor.dev/issue/1491): Should we instead completely
+ // invalidate all Routes that were bound to a NIC that was brought
+ // down at some point?
if err := upFn(s, nicID1); err != nil {
t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
}
@@ -886,7 +915,7 @@ func TestRoutes(t *testing.T) {
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep1 := channel.New(10, defaultMTU, "")
@@ -966,7 +995,7 @@ func TestAddressRemoval(t *testing.T) {
remoteAddr := tcpip.Address("\x02")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1013,7 +1042,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
remoteAddr := tcpip.Address("\x02")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1104,7 +1133,7 @@ func TestEndpointExpiration(t *testing.T) {
for _, spoofing := range []bool{true, false} {
t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1262,7 +1291,7 @@ func TestEndpointExpiration(t *testing.T) {
func TestPromiscuousMode(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1314,7 +1343,7 @@ func TestSpoofingWithAddress(t *testing.T) {
dstAddr := tcpip.Address("\x03")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1380,7 +1409,7 @@ func TestSpoofingNoAddress(t *testing.T) {
dstAddr := tcpip.Address("\x02")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1443,7 +1472,7 @@ func verifyRoute(gotRoute, wantRoute stack.Route) error {
func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1486,7 +1515,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// Create a new stack with two NICs.
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1587,7 +1616,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
@@ -1642,239 +1671,24 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
}
}
-// Add a range of addresses, then check that a packet is delivered.
-func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- const localAddrByte byte = 0x01
- buf[dstAddrOffset] = localAddrByte
- subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
- if err != nil {
- t.Fatal("NewSubnet failed:", err)
- }
- if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddAddressRange failed:", err)
- }
-
- testRecv(t, fakeNet, localAddrByte, ep, buf)
-}
-
-func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) {
- t.Helper()
-
- // Loop over all addresses and check them.
- numOfAddresses := 1 << uint(8-subnet.Prefix())
- if numOfAddresses < 1 || numOfAddresses > 255 {
- t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
- }
-
- addrBytes := []byte(subnet.ID())
- for i := 0; i < numOfAddresses; i++ {
- addr := tcpip.Address(addrBytes)
- wantNicID := nicID
- // The subnet and broadcast addresses are skipped.
- if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() {
- wantNicID = 0
- }
- if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID)
- }
- addrBytes[0]++
- }
-
- // Trying the next address should always fail since it is outside the range.
- if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0)
- }
-}
-
-// Set a range of addresses, then remove it again, and check at each step that
-// CheckLocalAddress returns the correct NIC for each address or zero if not
-// existent.
-func TestCheckLocalAddressForSubnet(t *testing.T) {
- const nicID tcpip.NICID = 1
+func TestNetworkOption(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ TransportProtocols: []stack.TransportProtocolFactory{},
})
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}})
- }
-
- subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0"))
- if err != nil {
- t.Fatal("NewSubnet failed:", err)
+ opt := tcpip.DefaultTTLOption(5)
+ if err := s.SetNetworkProtocolOption(fakeNetNumber, &opt); err != nil {
+ t.Fatalf("s.SetNetworkProtocolOption(%d, &%T(%d)): %s", fakeNetNumber, opt, opt, err)
}
- testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
-
- if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddAddressRange failed:", err)
+ var optGot tcpip.DefaultTTLOption
+ if err := s.NetworkProtocolOption(fakeNetNumber, &optGot); err != nil {
+ t.Fatalf("s.NetworkProtocolOption(%d, &%T): %s", fakeNetNumber, optGot, err)
}
- testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */)
-
- if err := s.RemoveAddressRange(nicID, subnet); err != nil {
- t.Fatal("RemoveAddressRange failed:", err)
- }
-
- testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
-}
-
-// Set a range of addresses, then send a packet to a destination outside the
-// range and then check it doesn't get delivered.
-func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- const localAddrByte byte = 0x01
- buf[dstAddrOffset] = localAddrByte
- subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
- if err != nil {
- t.Fatal("NewSubnet failed:", err)
- }
- if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddAddressRange failed:", err)
- }
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
-}
-
-func TestNetworkOptions(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{},
- })
-
- // Try an unsupported network protocol.
- if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
- t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
- }
-
- testCases := []struct {
- option interface{}
- wantErr *tcpip.Error
- verifier func(t *testing.T, p stack.NetworkProtocol)
- }{
- {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) {
- t.Helper()
- fakeNet := p.(*fakeNetworkProtocol)
- if fakeNet.opts.good != true {
- t.Fatalf("fakeNet.opts.good = false, want = true")
- }
- var v fakeNetGoodOption
- if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil {
- t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err)
- }
- if v != true {
- t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v)
- }
- }},
- {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
- {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
- }
- for _, tc := range testCases {
- if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr {
- t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr)
- }
- if tc.verifier != nil {
- tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber))
- }
- }
-}
-
-func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool {
- ranges, ok := s.NICAddressRanges()[id]
- if !ok {
- return false
- }
- for _, r := range ranges {
- if r == addrRange {
- return true
- }
- }
- return false
-}
-
-func TestAddresRangeAddRemove(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- addr := tcpip.Address("\x01\x01\x01\x01")
- mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
- addrRange, err := tcpip.NewSubnet(addr, mask)
- if err != nil {
- t.Fatal("NewSubnet failed:", err)
- }
-
- if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
- t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
- }
-
- if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil {
- t.Fatal("AddAddressRange failed:", err)
- }
-
- if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want {
- t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
- }
-
- if err := s.RemoveAddressRange(1, addrRange); err != nil {
- t.Fatal("RemoveAddressRange failed:", err)
- }
-
- if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
- t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
+ if opt != optGot {
+ t.Errorf("got optGot = %d, want = %d", optGot, opt)
}
}
@@ -1886,7 +1700,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
for never := 0; never < 3; never++ {
t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1953,7 +1767,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
func TestGetMainNICAddressAddRemove(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -2038,7 +1852,7 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
func TestAddAddress(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID, ep); err != nil {
@@ -2065,7 +1879,7 @@ func TestAddAddress(t *testing.T) {
func TestAddProtocolAddress(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID, ep); err != nil {
@@ -2099,7 +1913,7 @@ func TestAddProtocolAddress(t *testing.T) {
func TestAddAddressWithOptions(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID, ep); err != nil {
@@ -2130,7 +1944,7 @@ func TestAddAddressWithOptions(t *testing.T) {
func TestAddProtocolAddressWithOptions(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID, ep); err != nil {
@@ -2251,7 +2065,7 @@ func TestCreateNICWithOptions(t *testing.T) {
func TestNICStats(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
@@ -2271,9 +2085,9 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
}
@@ -2318,9 +2132,9 @@ func TestNICForwarding(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
- s.SetForwarding(true)
+ s.SetForwarding(fakeNetNumber, true)
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID1, ep1); err != nil {
@@ -2353,9 +2167,9 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to dstAddr.
buf := buffer.NewView(30)
buf[dstAddrOffset] = dstAddr[0]
- ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
pkt, ok := ep2.Read()
if !ok {
@@ -2363,8 +2177,8 @@ func TestNICForwarding(t *testing.T) {
}
// Test that the link's MaxHeaderLength is honoured.
- if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want {
- t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want)
+ if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want {
+ t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want)
}
// Test that forwarding increments Tx stats correctly.
@@ -2442,7 +2256,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName string
autoGen bool
linkAddr tcpip.LinkAddress
- iidOpts stack.OpaqueInterfaceIdentifierOptions
+ iidOpts ipv6.OpaqueInterfaceIdentifierOptions
shouldGen bool
expectedAddr tcpip.Address
}{
@@ -2458,7 +2272,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "nic1",
autoGen: false,
linkAddr: linkAddr1,
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:],
},
@@ -2503,7 +2317,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "nic1",
autoGen: true,
linkAddr: linkAddr1,
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:],
},
@@ -2515,7 +2329,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
{
name: "OIID Empty MAC and empty nicName",
autoGen: true,
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:1],
},
@@ -2527,7 +2341,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "test",
autoGen: true,
linkAddr: "\x01\x02\x03",
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:2],
},
@@ -2539,7 +2353,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "test2",
autoGen: true,
linkAddr: "\x01\x02\x03\x04\x05\x06",
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
SecretKey: secretKey[:3],
},
@@ -2551,7 +2365,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
nicName: "test3",
autoGen: true,
linkAddr: "\x00\x00\x00\x00\x00\x00",
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: nicNameFunc,
},
shouldGen: true,
@@ -2565,10 +2379,11 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: test.autoGen,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: test.iidOpts,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: test.autoGen,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: test.iidOpts,
+ })},
}
e := channel.New(0, 1280, test.linkAddr)
@@ -2640,15 +2455,15 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
tests := []struct {
name string
- opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions
+ opaqueIIDOpts ipv6.OpaqueInterfaceIdentifierOptions
}{
{
name: "IID From MAC",
- opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{},
+ opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{},
},
{
name: "Opaque IID",
- opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(_ tcpip.NICID, nicName string) string {
return nicName
},
@@ -2659,9 +2474,10 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: true,
- OpaqueIIDOpts: test.opaqueIIDOpts,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenIPv6LinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
+ })},
}
e := loopback.New()
@@ -2690,12 +2506,13 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
- ndpConfigs := stack.DefaultNDPConfigurations()
+ ndpConfigs := ipv6.DefaultNDPConfigurations()
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: ndpConfigs,
- AutoGenIPv6LinkLocal: true,
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ })},
}
e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1)
@@ -2751,7 +2568,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
for _, ps := range pebs {
t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
@@ -3042,14 +2859,15 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDispatcher{},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDispatcher{},
+ })},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -3088,59 +2906,58 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
const nicID = 1
+ broadcastAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 32,
+ },
+ }
e := loopback.New()
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
})
nicOpts := stack.NICOptions{Disabled: true}
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
}
- allStackAddrs := s.AllAddresses()
- allNICAddrs, ok := allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 0 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ {
+ allStackAddrs := s.AllAddresses()
+ if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ } else if containsAddr(allNICAddrs, broadcastAddr) {
+ t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr)
+ }
}
// Enabling the NIC should add the IPv4 broadcast address.
if err := s.EnableNIC(nicID); err != nil {
t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
}
- allStackAddrs = s.AllAddresses()
- allNICAddrs, ok = allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 1 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 1", l)
- }
- want := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: header.IPv4Broadcast,
- PrefixLen: 32,
- },
- }
- if allNICAddrs[0] != want {
- t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want)
+
+ {
+ allStackAddrs := s.AllAddresses()
+ if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ } else if !containsAddr(allNICAddrs, broadcastAddr) {
+ t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr)
+ }
}
// Disabling the NIC should remove the IPv4 broadcast address.
if err := s.DisableNIC(nicID); err != nil {
t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
}
- allStackAddrs = s.AllAddresses()
- allNICAddrs, ok = allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 0 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+
+ {
+ allStackAddrs := s.AllAddresses()
+ if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ } else if containsAddr(allNICAddrs, broadcastAddr) {
+ t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr)
+ }
}
}
@@ -3151,7 +2968,7 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
})
e := channel.New(10, 1280, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
@@ -3188,50 +3005,93 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
}
}
-func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
+func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) {
const nicID = 1
- e := loopback.New()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- })
- nicOpts := stack.NICOptions{Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.Address
+ }{
+ {
+ name: "IPv6 All-Nodes",
+ proto: header.IPv6ProtocolNumber,
+ addr: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "IPv4 All-Systems",
+ proto: header.IPv4ProtocolNumber,
+ addr: header.IPv4AllSystems,
+ },
}
- // Should not be in the IPv6 all-nodes multicast group yet because the NIC has
- // not been enabled yet.
- isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
- }
- if isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
- // The all-nodes multicast group should be joined when the NIC is enabled.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
- }
- if !isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress)
- }
+ // Should not be in the multicast group yet because the NIC has not been
+ // enabled yet.
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
+ }
- // The all-nodes multicast group should be left when the NIC is disabled.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
- }
- if isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ // The all-nodes multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr)
+ }
+
+ // The multicast group should be left when the NIC is disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
+ }
+
+ // The all-nodes multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr)
+ }
+
+ // Leaving the group before disabling the NIC should not cause an error.
+ if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil {
+ t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err)
+ }
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
+ }
+ })
}
}
@@ -3246,12 +3106,13 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
dadC: make(chan ndpDADEvent),
}
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- },
- NDPDisp: &ndpDisp,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
+ NDPDisp: &ndpDisp,
+ })},
}
e := channel.New(dadTransmits, 1280, linkAddr1)
@@ -3418,3 +3279,399 @@ func TestStackSendBufferSizeOption(t *testing.T) {
})
}
}
+
+func TestOutgoingSubnetBroadcast(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID1 = 1
+ )
+
+ defaultAddr := tcpip.AddressWithPrefix{
+ Address: header.IPv4Any,
+ PrefixLen: 0,
+ }
+ defaultSubnet := defaultAddr.Subnet()
+ ipv4Addr := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 24,
+ }
+ ipv4Subnet := ipv4Addr.Subnet()
+ ipv4SubnetBcast := ipv4Subnet.Broadcast()
+ ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 31,
+ }
+ ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
+ ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
+ ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 32,
+ }
+ ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
+ ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
+ ipv6Addr := tcpip.AddressWithPrefix{
+ Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ PrefixLen: 64,
+ }
+ ipv6Subnet := ipv6Addr.Subnet()
+ ipv6SubnetBcast := ipv6Subnet.Broadcast()
+ remNetAddr := tcpip.AddressWithPrefix{
+ Address: "\x64\x0a\x7b\x18",
+ PrefixLen: 24,
+ }
+ remNetSubnet := remNetAddr.Subnet()
+ remNetSubnetBcast := remNetSubnet.Broadcast()
+
+ tests := []struct {
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ expectedRoute stack.Route
+ }{
+ // Broadcast to a locally attached subnet populates the broadcast MAC.
+ {
+ name: "IPv4 Broadcast to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4SubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4Addr.Address,
+ RemoteAddress: ipv4SubnetBcast,
+ RemoteLinkAddress: header.EthernetBroadcastAddress,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to a locally attached /31 subnet does not populate the
+ // broadcast MAC.
+ {
+ name: "IPv4 Broadcast to local /31 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix31,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet31,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet31Bcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4AddrPrefix31.Address,
+ RemoteAddress: ipv4Subnet31Bcast,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to a locally attached /32 subnet does not populate the
+ // broadcast MAC.
+ {
+ name: "IPv4 Broadcast to local /32 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix32,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet32,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet32Bcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4AddrPrefix32.Address,
+ RemoteAddress: ipv4Subnet32Bcast,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 'Broadcast' to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv6Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv6SubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv6Addr.Address,
+ RemoteAddress: ipv6SubnetBcast,
+ NetProto: header.IPv6ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to a remote subnet in the route table is send to the next-hop
+ // gateway.
+ {
+ name: "IPv4 Broadcast to remote subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: remNetSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4Addr.Address,
+ RemoteAddress: remNetSubnetBcast,
+ NextHop: ipv4Gateway,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to an unknown subnet follows the default route. Note that this
+ // is essentially just routing an unknown destination IP, because w/o any
+ // subnet prefix information a subnet broadcast address is just a normal IP.
+ {
+ name: "IPv4 Broadcast to unknown subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: defaultSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4Addr.Address,
+ RemoteAddress: remNetSubnetBcast,
+ NextHop: ipv4Gateway,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+ ep := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ }
+
+ s.SetRouteTable(test.routes)
+
+ var netProto tcpip.NetworkProtocolNumber
+ switch l := len(test.remoteAddr); l {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ netProto = header.IPv6ProtocolNumber
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
+ } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" {
+ t.Errorf("route mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestResolveWith(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID = 1
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
+ })
+ ep := channel.New(0, defaultMTU, "")
+ ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ addr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
+ PrefixLen: 24,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
+
+ remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4())
+ r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ // Should initially require resolution.
+ if !r.IsResolutionRequired() {
+ t.Fatal("got r.IsResolutionRequired() = false, want = true")
+ }
+
+ // Manually resolving the route should no longer require resolution.
+ r.ResolveWith("\x01")
+ if r.IsResolutionRequired() {
+ t.Fatal("got r.IsResolutionRequired() = true, want = false")
+ }
+}
+
+// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its
+// associated address is removed should not cause a panic.
+func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
+ const (
+ nicID = 1
+ localAddr = tcpip.Address("\x01")
+ remoteAddr = tcpip.Address("\x02")
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ })
+
+ ep := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ r, err := s.FindRoute(nicID, localAddr, remoteAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, localAddr, remoteAddr, fakeNetNumber, err)
+ }
+ // Should not panic.
+ defer r.Release()
+
+ // Check that removing the same address fails.
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, localAddr, err)
+ }
+}
+
+func TestGetNetworkEndpoint(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ protoFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ }{
+ {
+ name: "IPv4",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ },
+ {
+ name: "IPv6",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ },
+ }
+
+ factories := make([]stack.NetworkProtocolFactory, 0, len(tests))
+ for _, test := range tests {
+ factories = append(factories, test.protoFactory)
+ }
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: factories,
+ })
+
+ if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep, err := s.GetNetworkEndpoint(nicID, test.protoNum)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, test.protoNum, err)
+ }
+
+ if got := ep.NetworkProtocolNumber(); got != test.protoNum {
+ t.Fatalf("got ep.NetworkProtocolNumber() = %d, want = %d", got, test.protoNum)
+ }
+ })
+ }
+}
+
+func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ })
+
+ if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err)
+ }
+
+ // Check that we get the right initial address and prefix length.
+ if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
+ } else if gotAddr != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ }
+
+ // Should still get the address when the NIC is diabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
+ } else if gotAddr != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index b902c6ca9..35e5b1a2e 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -155,7 +155,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
- mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[r.nic.ID()]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isMulticastOrBroadcast(id.LocalAddress) {
+ if isInboundMulticastOrBroadcast(r) {
mpep.handlePacketAll(r, id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
@@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
@@ -544,9 +544,11 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
return true
}
- // If the packet is a TCP packet with a non-unicast source or destination
- // address, then do nothing further and instruct the caller to do the same.
- if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ // If the packet is a TCP packet with a unspecified source or non-unicast
+ // destination address, then do nothing further and instruct the caller to do
+ // the same. The network layer handles address validation for specified source
+ // addresses.
+ if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) {
// TCP can only be used to communicate between a single source and a
// single destination; the addresses must be unicast.
r.Stats().TCP.InvalidSegmentsReceived.Increment()
@@ -626,7 +628,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
epsByNIC.mu.RLock()
eps.mu.RUnlock()
- mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[r.nic.ID()]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -677,10 +679,10 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isMulticastOrBroadcast(addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+func isInboundMulticastOrBroadcast(r *Route) bool {
+ return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
}
-func isUnicast(addr tcpip.Address) bool {
- return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+func isSpecified(addr tcpip.Address) bool {
+ return addr != header.IPv4Any && addr != header.IPv6Any
}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 73dada928..698c8609e 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -51,8 +51,8 @@ type testContext struct {
// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
linkEps := make(map[tcpip.NICID]*channel.Endpoint)
for _, linkEpID := range linkEpIDs {
@@ -128,11 +128,10 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
})
+ c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt)
}
func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
@@ -166,11 +165,10 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
})
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt)
}
func TestTransportDemuxerRegister(t *testing.T) {
@@ -184,8 +182,8 @@ func TestTransportDemuxerRegister(t *testing.T) {
} {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
@@ -314,8 +312,8 @@ func TestBindToDeviceDistribution(t *testing.T) {
t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
}
bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
- if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
- t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err)
+ if err := ep.SetSockOpt(&bindToDeviceOption); err != nil {
+ t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err)
}
var dstAddr tcpip.Address
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 7e8b84867..62ab6d92f 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -39,7 +39,7 @@ const (
// use it.
type fakeTransportEndpoint struct {
stack.TransportEndpointInfo
- stack *stack.Stack
+
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
@@ -53,14 +53,14 @@ func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo {
return &f.TransportEndpointInfo
}
-func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
+func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats {
return nil
}
-func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
+func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
-func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
+ return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
func (f *fakeTransportEndpoint) Abort() {
@@ -84,28 +84,28 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return 0, nil, tcpip.ErrNoRoute
}
- hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen)
- hdr.Prepend(fakeTransHeaderLen)
v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
- if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
- Header: hdr,
- Data: buffer.View(v).ToVectorisedView(),
- }); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen,
+ Data: buffer.View(v).ToVectorisedView(),
+ })
+ _ = pkt.TransportHeader().Push(fakeTransHeaderLen)
+ if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
}
-func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
// SetSockOpt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
+func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
@@ -130,11 +130,7 @@ func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.E
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
- return nil
- }
+func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
@@ -147,7 +143,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
// Find the route.
- r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
+ r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
return tcpip.ErrNoRoute
}
@@ -155,7 +151,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
+ err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -169,7 +165,7 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 {
return f.uniqueID
}
-func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
+func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
return nil
}
@@ -184,7 +180,7 @@ func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
return nil
}
-func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
if len(f.acceptQueue) == 0 {
return nil, nil, nil
}
@@ -194,7 +190,7 @@ func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.
}
func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
- if err := f.stack.RegisterTransportEndpoint(
+ if err := f.proto.stack.RegisterTransportEndpoint(
a.NIC,
[]tcpip.NetworkProtocolNumber{fakeNetNumber},
fakeTransNumber,
@@ -222,7 +218,6 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
f.proto.packetCount++
if f.acceptQueue != nil {
f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- stack: f.stack,
TransportEndpointInfo: stack.TransportEndpointInfo{
ID: f.ID,
NetProto: f.NetProto,
@@ -239,19 +234,19 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s
f.proto.controlCount++
}
-func (f *fakeTransportEndpoint) State() uint32 {
+func (*fakeTransportEndpoint) State() uint32 {
return 0
}
-func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
+func (*fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
-func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) {
- return stack.IPTables{}, nil
-}
+func (*fakeTransportEndpoint) Resume(*stack.Stack) {}
-func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
+func (*fakeTransportEndpoint) Wait() {}
-func (f *fakeTransportEndpoint) Wait() {}
+func (*fakeTransportEndpoint) LastError() *tcpip.Error {
+ return nil
+}
type fakeTransportGoodOption bool
@@ -266,6 +261,8 @@ type fakeTransportProtocolOptions struct {
// fakeTransportProtocol is a transport-layer protocol descriptor. It
// aggregates the number of packets received via endpoints of this protocol.
type fakeTransportProtocol struct {
+ stack *stack.Stack
+
packetCount int
controlCount int
opts fakeTransportProtocolOptions
@@ -275,11 +272,11 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
return fakeTransNumber
}
-func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
+func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil
}
-func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrUnknownProtocol
}
@@ -291,26 +288,24 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
- return true
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ return stack.UnknownDestinationPacketHandled
}
-func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error {
+func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case fakeTransportGoodOption:
- f.opts.good = bool(v)
+ case *tcpip.TCPModerateReceiveBufferOption:
+ f.opts.good = bool(*v)
return nil
- case fakeTransportInvalidValueOption:
- return tcpip.ErrInvalidOptionValue
default:
return tcpip.ErrUnknownProtocolOption
}
}
-func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
+func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case *fakeTransportGoodOption:
- *v = fakeTransportGoodOption(f.opts.good)
+ case *tcpip.TCPModerateReceiveBufferOption:
+ *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good)
return nil
default:
return tcpip.ErrUnknownProtocolOption
@@ -328,24 +323,19 @@ func (*fakeTransportProtocol) Wait() {}
// Parse implements TransportProtocol.Parse.
func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
- hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen)
- if !ok {
- return false
- }
- pkt.TransportHeader = hdr
- pkt.Data.TrimFront(fakeTransHeaderLen)
- return true
+ _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen)
+ return ok
}
-func fakeTransFactory() stack.TransportProtocol {
- return &fakeTransportProtocol{}
+func fakeTransFactory(s *stack.Stack) stack.TransportProtocol {
+ return &fakeTransportProtocol{stack: s}
}
func TestTransportReceive(t *testing.T) {
linkEP := channel.New(10, defaultMTU, "")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
})
if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
@@ -382,9 +372,9 @@ func TestTransportReceive(t *testing.T) {
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
- linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
@@ -393,9 +383,9 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
@@ -404,9 +394,9 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeTrans.packetCount != 1 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1)
}
@@ -415,8 +405,8 @@ func TestTransportReceive(t *testing.T) {
func TestTransportControlReceive(t *testing.T) {
linkEP := channel.New(10, defaultMTU, "")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
})
if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
@@ -459,9 +449,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 0
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = 0
- linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeTrans.controlCount != 0 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
}
@@ -470,9 +460,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 3
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeTrans.controlCount != 0 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
}
@@ -481,9 +471,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 2
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeTrans.controlCount != 1 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1)
}
@@ -492,8 +482,8 @@ func TestTransportControlReceive(t *testing.T) {
func TestTransportSend(t *testing.T) {
linkEP := channel.New(10, defaultMTU, "")
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
})
if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
@@ -538,54 +528,29 @@ func TestTransportSend(t *testing.T) {
func TestTransportOptions(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
})
- // Try an unsupported transport protocol.
- if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
- t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
- }
-
- testCases := []struct {
- option interface{}
- wantErr *tcpip.Error
- verifier func(t *testing.T, p stack.TransportProtocol)
- }{
- {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) {
- t.Helper()
- fakeTrans := p.(*fakeTransportProtocol)
- if fakeTrans.opts.good != true {
- t.Fatalf("fakeTrans.opts.good = false, want = true")
- }
- var v fakeTransportGoodOption
- if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
- t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err)
- }
- if v != true {
- t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v)
- }
-
- }},
- {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
- {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
- }
- for _, tc := range testCases {
- if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr {
- t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr)
- }
- if tc.verifier != nil {
- tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber))
- }
+ v := tcpip.TCPModerateReceiveBufferOption(true)
+ if err := s.SetTransportProtocolOption(fakeTransNumber, &v); err != nil {
+ t.Errorf("s.SetTransportProtocolOption(fakeTrans, &%T(%t)): %s", v, v, err)
+ }
+ v = false
+ if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
+ t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &%T): %s", v, err)
+ }
+ if !v {
+ t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true")
}
}
func TestTransportForwarding(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
})
- s.SetForwarding(true)
+ s.SetForwarding(fakeNetNumber, true)
// TODO(b/123449044): Change this to a channel NIC.
ep1 := loopback.New()
@@ -636,11 +601,11 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: req.ToVectorisedView(),
- })
+ }))
- aep, _, err := ep.Accept()
+ aep, _, err := ep.Accept(nil)
if err != nil || aep == nil {
t.Fatalf("Accept failed: %v, %v", aep, err)
}
@@ -655,10 +620,11 @@ func TestTransportForwarding(t *testing.T) {
t.Fatal("Response packet not forwarded")
}
- if dst := p.Pkt.NetworkHeader[0]; dst != 3 {
+ nh := stack.PayloadSince(p.Pkt.NetworkHeader())
+ if dst := nh[0]; dst != 3 {
t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
}
- if src := p.Pkt.NetworkHeader[1]; src != 1 {
+ if src := nh[1]; src != 1 {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 2be1c107a..c42bb0991 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -43,6 +43,9 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// Using header.IPv4AddressSize would cause an import cycle.
+const ipv4AddressSize = 4
+
// Error represents an error in the netstack error space. Using a special type
// ensures that errors outside of this space are not accidentally introduced.
//
@@ -192,7 +195,7 @@ func (e ErrSaveRejection) Error() string {
return "save rejected due to unsupported networking state: " + e.Err.Error()
}
-// A Clock provides the current time.
+// A Clock provides the current time and schedules work for execution.
//
// Times returned by a Clock should always be used for application-visible
// time. Only monotonic times should be used for netstack internal timekeeping.
@@ -203,12 +206,45 @@ type Clock interface {
// NowMonotonic returns a monotonic time value.
NowMonotonic() int64
+
+ // AfterFunc waits for the duration to elapse and then calls f in its own
+ // goroutine. It returns a Timer that can be used to cancel the call using
+ // its Stop method.
+ AfterFunc(d time.Duration, f func()) Timer
+}
+
+// Timer represents a single event. A Timer must be created with
+// Clock.AfterFunc.
+type Timer interface {
+ // Stop prevents the Timer from firing. It returns true if the call stops the
+ // timer, false if the timer has already expired or been stopped.
+ //
+ // If Stop returns false, then the timer has already expired and the function
+ // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop
+ // does not wait for f to complete before returning. If the caller needs to
+ // know whether f is completed, it must coordinate with f explicitly.
+ Stop() bool
+
+ // Reset changes the timer to expire after duration d.
+ //
+ // Reset should be invoked only on stopped or expired timers. If the timer is
+ // known to have expired, Reset can be used directly. Otherwise, the caller
+ // must coordinate with the function f of Clock.AfterFunc(d, f).
+ Reset(d time.Duration)
}
// Address is a byte slice cast as a string that represents the address of a
// network node. Or, in the case of unix endpoints, it may represent a path.
type Address string
+// WithPrefix returns the address with a prefix that represents a point subnet.
+func (a Address) WithPrefix() AddressWithPrefix {
+ return AddressWithPrefix{
+ Address: a,
+ PrefixLen: len(a) * 8,
+ }
+}
+
// AddressMask is a bitmask for an address.
type AddressMask string
@@ -295,6 +331,29 @@ func (s *Subnet) Broadcast() Address {
return Address(addr)
}
+// IsBroadcast returns true if the address is considered a broadcast address.
+func (s *Subnet) IsBroadcast(address Address) bool {
+ // Only IPv4 supports the notion of a broadcast address.
+ if len(address) != ipv4AddressSize {
+ return false
+ }
+
+ // Normally, we would just compare address with the subnet's broadcast
+ // address but there is an exception where a simple comparison is not
+ // correct. This exception is for /31 and /32 IPv4 subnets where all
+ // addresses are considered valid host addresses.
+ //
+ // For /31 subnets, the case is easy. RFC 3021 Section 2.1 states that
+ // both addresses in a /31 subnet "MUST be interpreted as host addresses."
+ //
+ // For /32, the case is a bit more vague. RFC 3021 makes no mention of /32
+ // subnets. However, the same reasoning applies - if an exception is not
+ // made, then there do not exist any host addresses in a /32 subnet. RFC
+ // 4632 Section 3.1 also vaguely implies this interpretation by referring
+ // to addresses in /32 subnets as "host routes."
+ return s.Prefix() <= 30 && s.Broadcast() == address
+}
+
// Equal returns true if s equals o.
//
// Needed to use cmp.Equal on Subnet as its fields are unexported.
@@ -316,6 +375,28 @@ const (
ShutdownWrite
)
+// PacketType is used to indicate the destination of the packet.
+type PacketType uint8
+
+const (
+ // PacketHost indicates a packet addressed to the local host.
+ PacketHost PacketType = iota
+
+ // PacketOtherHost indicates an outgoing packet addressed to
+ // another host caught by a NIC in promiscuous mode.
+ PacketOtherHost
+
+ // PacketOutgoing for a packet originating from the local host
+ // that is looped back to a packet socket.
+ PacketOutgoing
+
+ // PacketBroadcast indicates a link layer broadcast packet.
+ PacketBroadcast
+
+ // PacketMulticast indicates a link layer multicast packet.
+ PacketMulticast
+)
+
// FullAddress represents a full transport node address, as required by the
// Connect() and Bind() methods.
//
@@ -488,7 +569,10 @@ type Endpoint interface {
// block if no new connections are available.
//
// The returned Queue is the wait queue for the newly created endpoint.
- Accept() (Endpoint, *waiter.Queue, *Error)
+ //
+ // If peerAddr is not nil then it is populated with the peer address of the
+ // returned endpoint.
+ Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error)
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
@@ -505,8 +589,8 @@ type Endpoint interface {
// if waiter.EventIn is set, the endpoint is immediately readable.
Readiness(mask waiter.EventMask) waiter.EventMask
- // SetSockOpt sets a socket option. opt should be one of the *Option types.
- SetSockOpt(opt interface{}) *Error
+ // SetSockOpt sets a socket option.
+ SetSockOpt(opt SettableSocketOption) *Error
// SetSockOptBool sets a socket option, for simple cases where a value
// has the bool type.
@@ -516,9 +600,8 @@ type Endpoint interface {
// has the int type.
SetSockOptInt(opt SockOptInt, v int) *Error
- // GetSockOpt gets a socket option. opt should be a pointer to one of the
- // *Option types.
- GetSockOpt(opt interface{}) *Error
+ // GetSockOpt gets a socket option.
+ GetSockOpt(opt GettableSocketOption) *Error
// GetSockOptBool gets a socket option for simple cases where a return
// value has the bool type.
@@ -547,6 +630,31 @@ type Endpoint interface {
// SetOwner sets the task owner to the endpoint owner.
SetOwner(owner PacketOwner)
+
+ // LastError clears and returns the last error reported by the endpoint.
+ LastError() *Error
+}
+
+// LinkPacketInfo holds Link layer information for a received packet.
+//
+// +stateify savable
+type LinkPacketInfo struct {
+ // Protocol is the NetworkProtocolNumber for the packet.
+ Protocol NetworkProtocolNumber
+
+ // PktType is used to indicate the destination of the packet.
+ PktType PacketType
+}
+
+// PacketEndpoint are additional methods that are only implemented by Packet
+// endpoints.
+type PacketEndpoint interface {
+ // ReadPacket reads a datagram/packet from the endpoint and optionally
+ // returns the sender and additional LinkPacketInfo.
+ //
+ // This method does not block if there is no data pending. It will also
+ // either return an error or data, never both.
+ ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error)
}
// EndpointInfo is the interface implemented by each endpoint info struct.
@@ -648,6 +756,11 @@ const (
// whether an IPv6 socket is to be restricted to sending and receiving
// IPv6 packets only.
V6OnlyOption
+
+ // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw
+ // endpoint that all packets being written have an IP header and the
+ // endpoint should not attach an IP header.
+ IPHdrIncludedOption
)
// SockOptInt represents socket options which values have the int type.
@@ -673,6 +786,13 @@ const (
// TCP_MAXSEG option.
MaxSegOption
+ // MTUDiscoverOption is used to set/get the path MTU discovery setting.
+ //
+ // NOTE: Setting this option to any other value than PMTUDiscoveryDont
+ // is not supported and will fail as such, and getting this option will
+ // always return PMTUDiscoveryDont.
+ MTUDiscoverOption
+
// MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control
// the default TTL value for multicast messages. The default is 1.
MulticastTTLOption
@@ -714,14 +834,152 @@ const (
TCPWindowClampOption
)
-// ErrorOption is used in GetSockOpt to specify that the last error reported by
-// the endpoint should be cleared and returned.
-type ErrorOption struct{}
+const (
+ // PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use
+ // per-route settings.
+ PMTUDiscoveryWant int = iota
+
+ // PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable
+ // path MTU discovery.
+ PMTUDiscoveryDont
+
+ // PMTUDiscoveryDo is a setting of the MTUDiscoverOption to always do
+ // path MTU discovery.
+ PMTUDiscoveryDo
+
+ // PMTUDiscoveryProbe is a setting of the MTUDiscoverOption to set DF
+ // but ignore path MTU.
+ PMTUDiscoveryProbe
+)
+
+// GettableNetworkProtocolOption is a marker interface for network protocol
+// options that may be queried.
+type GettableNetworkProtocolOption interface {
+ isGettableNetworkProtocolOption()
+}
+
+// SettableNetworkProtocolOption is a marker interface for network protocol
+// options that may be set.
+type SettableNetworkProtocolOption interface {
+ isSettableNetworkProtocolOption()
+}
+
+// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify
+// a default TTL.
+type DefaultTTLOption uint8
+
+func (*DefaultTTLOption) isGettableNetworkProtocolOption() {}
+
+func (*DefaultTTLOption) isSettableNetworkProtocolOption() {}
+
+// GettableTransportProtocolOption is a marker interface for transport protocol
+// options that may be queried.
+type GettableTransportProtocolOption interface {
+ isGettableTransportProtocolOption()
+}
+
+// SettableTransportProtocolOption is a marker interface for transport protocol
+// options that may be set.
+type SettableTransportProtocolOption interface {
+ isSettableTransportProtocolOption()
+}
+
+// TCPSACKEnabled the SACK option for TCP.
+//
+// See: https://tools.ietf.org/html/rfc2018.
+type TCPSACKEnabled bool
+
+func (*TCPSACKEnabled) isGettableTransportProtocolOption() {}
+
+func (*TCPSACKEnabled) isSettableTransportProtocolOption() {}
+
+// TCPRecovery is the loss deteoction algorithm used by TCP.
+type TCPRecovery int32
+
+func (*TCPRecovery) isGettableTransportProtocolOption() {}
+
+func (*TCPRecovery) isSettableTransportProtocolOption() {}
+
+const (
+ // TCPRACKLossDetection indicates RACK is used for loss detection and
+ // recovery.
+ TCPRACKLossDetection TCPRecovery = 1 << iota
+
+ // TCPRACKStaticReoWnd indicates the reordering window should not be
+ // adjusted when DSACK is received.
+ TCPRACKStaticReoWnd
+
+ // TCPRACKNoDupTh indicates RACK should not consider the classic three
+ // duplicate acknowledgements rule to mark the segments as lost. This
+ // is used when reordering is not detected.
+ TCPRACKNoDupTh
+)
+
+// TCPDelayEnabled enables/disables Nagle's algorithm in TCP.
+type TCPDelayEnabled bool
+
+func (*TCPDelayEnabled) isGettableTransportProtocolOption() {}
+
+func (*TCPDelayEnabled) isSettableTransportProtocolOption() {}
+
+// TCPSendBufferSizeRangeOption is the send buffer size range for TCP.
+type TCPSendBufferSizeRangeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+func (*TCPSendBufferSizeRangeOption) isGettableTransportProtocolOption() {}
+
+func (*TCPSendBufferSizeRangeOption) isSettableTransportProtocolOption() {}
+
+// TCPReceiveBufferSizeRangeOption is the receive buffer size range for TCP.
+type TCPReceiveBufferSizeRangeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+func (*TCPReceiveBufferSizeRangeOption) isGettableTransportProtocolOption() {}
+
+func (*TCPReceiveBufferSizeRangeOption) isSettableTransportProtocolOption() {}
+
+// TCPAvailableCongestionControlOption is the supported congestion control
+// algorithms for TCP
+type TCPAvailableCongestionControlOption string
+
+func (*TCPAvailableCongestionControlOption) isGettableTransportProtocolOption() {}
+
+func (*TCPAvailableCongestionControlOption) isSettableTransportProtocolOption() {}
+
+// TCPModerateReceiveBufferOption enables/disables receive buffer moderation
+// for TCP.
+type TCPModerateReceiveBufferOption bool
+
+func (*TCPModerateReceiveBufferOption) isGettableTransportProtocolOption() {}
+
+func (*TCPModerateReceiveBufferOption) isSettableTransportProtocolOption() {}
+
+// GettableSocketOption is a marker interface for socket options that may be
+// queried.
+type GettableSocketOption interface {
+ isGettableSocketOption()
+}
+
+// SettableSocketOption is a marker interface for socket options that may be
+// configured.
+type SettableSocketOption interface {
+ isSettableSocketOption()
+}
// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
// should bind only on a specific NIC.
type BindToDeviceOption NICID
+func (*BindToDeviceOption) isGettableSocketOption() {}
+
+func (*BindToDeviceOption) isSettableSocketOption() {}
+
// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
//
// TODO(b/64800844): Add and populate stat fields.
@@ -730,68 +988,143 @@ type TCPInfoOption struct {
RTTVar time.Duration
}
+func (*TCPInfoOption) isGettableSocketOption() {}
+
// KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a
// connection must remain idle before the first TCP keepalive packet is sent.
// Once this time is reached, KeepaliveIntervalOption is used instead.
type KeepaliveIdleOption time.Duration
+func (*KeepaliveIdleOption) isGettableSocketOption() {}
+
+func (*KeepaliveIdleOption) isSettableSocketOption() {}
+
// KeepaliveIntervalOption is used by SetSockOpt/GetSockOpt to specify the
// interval between sending TCP keepalive packets.
type KeepaliveIntervalOption time.Duration
+func (*KeepaliveIntervalOption) isGettableSocketOption() {}
+
+func (*KeepaliveIntervalOption) isSettableSocketOption() {}
+
// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user
// specified timeout for a given TCP connection.
// See: RFC5482 for details.
type TCPUserTimeoutOption time.Duration
+func (*TCPUserTimeoutOption) isGettableSocketOption() {}
+
+func (*TCPUserTimeoutOption) isSettableSocketOption() {}
+
// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get
// the current congestion control algorithm.
type CongestionControlOption string
-// AvailableCongestionControlOption is used to query the supported congestion
-// control algorithms.
-type AvailableCongestionControlOption string
+func (*CongestionControlOption) isGettableSocketOption() {}
+
+func (*CongestionControlOption) isSettableSocketOption() {}
-// buffer moderation.
-type ModerateReceiveBufferOption bool
+func (*CongestionControlOption) isGettableTransportProtocolOption() {}
+
+func (*CongestionControlOption) isSettableTransportProtocolOption() {}
// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state
// before being marked closed.
type TCPLingerTimeoutOption time.Duration
+func (*TCPLingerTimeoutOption) isGettableSocketOption() {}
+
+func (*TCPLingerTimeoutOption) isSettableSocketOption() {}
+
+func (*TCPLingerTimeoutOption) isGettableTransportProtocolOption() {}
+
+func (*TCPLingerTimeoutOption) isSettableTransportProtocolOption() {}
+
// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
// maximum duration for which a socket lingers in the TIME_WAIT state
// before being marked closed.
type TCPTimeWaitTimeoutOption time.Duration
+func (*TCPTimeWaitTimeoutOption) isGettableSocketOption() {}
+
+func (*TCPTimeWaitTimeoutOption) isSettableSocketOption() {}
+
+func (*TCPTimeWaitTimeoutOption) isGettableTransportProtocolOption() {}
+
+func (*TCPTimeWaitTimeoutOption) isSettableTransportProtocolOption() {}
+
// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a
// accept to return a completed connection only when there is data to be
// read. This usually means the listening socket will drop the final ACK
// for a handshake till the specified timeout until a segment with data arrives.
type TCPDeferAcceptOption time.Duration
+func (*TCPDeferAcceptOption) isGettableSocketOption() {}
+
+func (*TCPDeferAcceptOption) isSettableSocketOption() {}
+
// TCPMinRTOOption is use by SetSockOpt/GetSockOpt to allow overriding
// default MinRTO used by the Stack.
type TCPMinRTOOption time.Duration
+func (*TCPMinRTOOption) isGettableSocketOption() {}
+
+func (*TCPMinRTOOption) isSettableSocketOption() {}
+
+func (*TCPMinRTOOption) isGettableTransportProtocolOption() {}
+
+func (*TCPMinRTOOption) isSettableTransportProtocolOption() {}
+
// TCPMaxRTOOption is use by SetSockOpt/GetSockOpt to allow overriding
// default MaxRTO used by the Stack.
type TCPMaxRTOOption time.Duration
+func (*TCPMaxRTOOption) isGettableSocketOption() {}
+
+func (*TCPMaxRTOOption) isSettableSocketOption() {}
+
+func (*TCPMaxRTOOption) isGettableTransportProtocolOption() {}
+
+func (*TCPMaxRTOOption) isSettableTransportProtocolOption() {}
+
// TCPMaxRetriesOption is used by SetSockOpt/GetSockOpt to set/get the
// maximum number of retransmits after which we time out the connection.
type TCPMaxRetriesOption uint64
+func (*TCPMaxRetriesOption) isGettableSocketOption() {}
+
+func (*TCPMaxRetriesOption) isSettableSocketOption() {}
+
+func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {}
+
+func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {}
+
// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify
// the number of endpoints that can be in SYN-RCVD state before the stack
// switches to using SYN cookies.
type TCPSynRcvdCountThresholdOption uint64
+func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {}
+
+func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {}
+
+func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {}
+
+func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {}
+
// TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide
// default for number of times SYN is retransmitted before aborting a connect.
type TCPSynRetriesOption uint8
+func (*TCPSynRetriesOption) isGettableSocketOption() {}
+
+func (*TCPSynRetriesOption) isSettableSocketOption() {}
+
+func (*TCPSynRetriesOption) isGettableTransportProtocolOption() {}
+
+func (*TCPSynRetriesOption) isSettableTransportProtocolOption() {}
+
// MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a
// default interface for multicast.
type MulticastInterfaceOption struct {
@@ -799,33 +1132,89 @@ type MulticastInterfaceOption struct {
InterfaceAddr Address
}
-// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
-// AddMembershipOption and RemoveMembershipOption.
+func (*MulticastInterfaceOption) isGettableSocketOption() {}
+
+func (*MulticastInterfaceOption) isSettableSocketOption() {}
+
+// MembershipOption is used to identify a multicast membership on an interface.
type MembershipOption struct {
NIC NICID
InterfaceAddr Address
MulticastAddr Address
}
-// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast
-// group identified by the given multicast address, on the interface matching
-// the given interface address.
+// AddMembershipOption identifies a multicast group to join on some interface.
type AddMembershipOption MembershipOption
-// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast
-// group identified by the given multicast address, on the interface matching
-// the given interface address.
+func (*AddMembershipOption) isSettableSocketOption() {}
+
+// RemoveMembershipOption identifies a multicast group to leave on some
+// interface.
type RemoveMembershipOption MembershipOption
+func (*RemoveMembershipOption) isSettableSocketOption() {}
+
// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether
// TCP out-of-band data is delivered along with the normal in-band data.
type OutOfBandInlineOption int
-// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify
-// a default TTL.
-type DefaultTTLOption uint8
+func (*OutOfBandInlineOption) isGettableSocketOption() {}
+
+func (*OutOfBandInlineOption) isSettableSocketOption() {}
+
+// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached
+// classic BPF filter on a given endpoint.
+type SocketDetachFilterOption int
+
+func (*SocketDetachFilterOption) isSettableSocketOption() {}
+// OriginalDestinationOption is used to get the original destination address
+// and port of a redirected packet.
+type OriginalDestinationOption FullAddress
+
+func (*OriginalDestinationOption) isGettableSocketOption() {}
+
+// TCPTimeWaitReuseOption is used stack.(*Stack).TransportProtocolOption to
+// specify if the stack can reuse the port bound by an endpoint in TIME-WAIT for
+// new connections when it is safe from protocol viewpoint.
+type TCPTimeWaitReuseOption uint8
+
+func (*TCPTimeWaitReuseOption) isGettableSocketOption() {}
+
+func (*TCPTimeWaitReuseOption) isSettableSocketOption() {}
+
+func (*TCPTimeWaitReuseOption) isGettableTransportProtocolOption() {}
+
+func (*TCPTimeWaitReuseOption) isSettableTransportProtocolOption() {}
+
+const (
+ // TCPTimeWaitReuseDisabled indicates reuse of port bound by endponts in TIME-WAIT cannot
+ // be reused for new connections.
+ TCPTimeWaitReuseDisabled TCPTimeWaitReuseOption = iota
+
+ // TCPTimeWaitReuseGlobal indicates reuse of port bound by endponts in TIME-WAIT can
+ // be reused for new connections irrespective of the src/dest addresses.
+ TCPTimeWaitReuseGlobal
+
+ // TCPTimeWaitReuseLoopbackOnly indicates reuse of port bound by endpoint in TIME-WAIT can
+ // only be reused if the connection was a connection over loopback. i.e src/dest adddresses
+ // are loopback addresses.
+ TCPTimeWaitReuseLoopbackOnly
+)
+
+// LingerOption is used by SetSockOpt/GetSockOpt to set/get the
+// duration for which a socket lingers before returning from Close.
//
+// +stateify savable
+type LingerOption struct {
+ Enabled bool
+ Timeout time.Duration
+}
+
+func (*LingerOption) isGettableSocketOption() {}
+
+func (*LingerOption) isSettableSocketOption() {}
+
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
@@ -836,7 +1225,7 @@ type IPPacketInfo struct {
// LocalAddr is the local address.
LocalAddr Address
- // DestinationAddr is the destination address.
+ // DestinationAddr is the destination address found in the IP header.
DestinationAddr Address
}
@@ -868,7 +1257,10 @@ func (r Route) String() string {
// TransportProtocolNumber is the number of a transport protocol.
type TransportProtocolNumber uint32
-// NetworkProtocolNumber is the number of a network protocol.
+// NetworkProtocolNumber is the EtherType of a network protocol in an Ethernet
+// frame.
+//
+// See: https://www.iana.org/assignments/ieee-802-numbers/ieee-802-numbers.xhtml
type NetworkProtocolNumber uint32
// A StatCounter keeps track of a statistic.
@@ -1031,6 +1423,10 @@ type ICMPv6ReceivedPacketStats struct {
// Invalid is the total number of ICMPv6 packets received that the
// transport layer could not parse.
Invalid *StatCounter
+
+ // RouterOnlyPacketsDroppedByHost is the total number of ICMPv6 packets
+ // dropped due to being router-specific packets.
+ RouterOnlyPacketsDroppedByHost *StatCounter
}
// ICMPStats collects ICMP-specific stats (both v4 and v6).
@@ -1086,6 +1482,18 @@ type IPStats struct {
// MalformedFragmentsReceived is the total number of IP Fragments that were
// dropped due to the fragment failing validation checks.
MalformedFragmentsReceived *StatCounter
+
+ // IPTablesPreroutingDropped is the total number of IP packets dropped
+ // in the Prerouting chain.
+ IPTablesPreroutingDropped *StatCounter
+
+ // IPTablesInputDropped is the total number of IP packets dropped in
+ // the Input chain.
+ IPTablesInputDropped *StatCounter
+
+ // IPTablesOutputDropped is the total number of IP packets dropped in
+ // the Output chain.
+ IPTablesOutputDropped *StatCounter
}
// TCPStats collects TCP-specific stats.
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
new file mode 100644
index 000000000..06c7a3cd3
--- /dev/null
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -0,0 +1,26 @@
+load("//tools:defs.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+go_test(
+ name = "integration_test",
+ size = "small",
+ srcs = [
+ "loopback_test.go",
+ "multicast_broadcast_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
new file mode 100644
index 000000000..e8caf09ba
--- /dev/null
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -0,0 +1,314 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
+
+type ndpDispatcher struct{}
+
+func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+}
+
+func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool {
+ return false
+}
+
+func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {}
+
+func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
+ return false
+}
+
+func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {}
+
+func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
+ return true
+}
+
+func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {}
+
+func (*ndpDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {}
+
+func (*ndpDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {}
+
+func (*ndpDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {}
+
+func (*ndpDispatcher) OnDHCPv6Configuration(tcpip.NICID, ipv6.DHCPv6ConfigurationFromNDPRA) {}
+
+// TestInitialLoopbackAddresses tests that the loopback interface does not
+// auto-generate a link-local address when it is brought up.
+func TestInitialLoopbackAddresses(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPDisp: &ndpDispatcher{},
+ AutoGenIPv6LinkLocal: true,
+ OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(nicID tcpip.NICID, nicName string) string {
+ t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName)
+ return ""
+ },
+ },
+ })},
+ })
+
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ nicsInfo := s.NICInfo()
+ if nicInfo, ok := nicsInfo[nicID]; !ok {
+ t.Fatalf("did not find NIC with ID = %d in s.NICInfo() = %#v", nicID, nicsInfo)
+ } else if got := len(nicInfo.ProtocolAddresses); got != 0 {
+ t.Fatalf("got len(nicInfo.ProtocolAddresses) = %d, want = 0; nicInfo.ProtocolAddresses = %#v", got, nicInfo.ProtocolAddresses)
+ }
+}
+
+// TestLoopbackAcceptAllInSubnet tests that a loopback interface considers
+// itself bound to all addresses in the subnet of an assigned address.
+func TestLoopbackAcceptAllInSubnet(t *testing.T) {
+ const (
+ nicID = 1
+ localPort = 80
+ )
+
+ data := []byte{1, 2, 3, 4}
+
+ ipv4ProtocolAddress := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ }
+ ipv4Bytes := []byte(ipv4Addr.Address)
+ ipv4Bytes[len(ipv4Bytes)-1]++
+ otherIPv4Address := tcpip.Address(ipv4Bytes)
+
+ ipv6ProtocolAddress := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ }
+ ipv6Bytes := []byte(ipv6Addr.Address)
+ ipv6Bytes[len(ipv6Bytes)-1]++
+ otherIPv6Address := tcpip.Address(ipv6Bytes)
+
+ tests := []struct {
+ name string
+ addAddress tcpip.ProtocolAddress
+ bindAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectRx bool
+ }{
+ {
+ name: "IPv4 bind to wildcard and send to assigned address",
+ addAddress: ipv4ProtocolAddress,
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 bind to wildcard and send to other subnet-local address",
+ addAddress: ipv4ProtocolAddress,
+ dstAddr: otherIPv4Address,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 bind to wildcard send to other address",
+ addAddress: ipv4ProtocolAddress,
+ dstAddr: remoteIPv4Addr,
+ expectRx: false,
+ },
+ {
+ name: "IPv4 bind to other subnet-local address and send to assigned address",
+ addAddress: ipv4ProtocolAddress,
+ bindAddr: otherIPv4Address,
+ dstAddr: ipv4Addr.Address,
+ expectRx: false,
+ },
+ {
+ name: "IPv4 bind and send to other subnet-local address",
+ addAddress: ipv4ProtocolAddress,
+ bindAddr: otherIPv4Address,
+ dstAddr: otherIPv4Address,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 bind to assigned address and send to other subnet-local address",
+ addAddress: ipv4ProtocolAddress,
+ bindAddr: ipv4Addr.Address,
+ dstAddr: otherIPv4Address,
+ expectRx: false,
+ },
+
+ {
+ name: "IPv6 bind and send to assigned address",
+ addAddress: ipv6ProtocolAddress,
+ bindAddr: ipv6Addr.Address,
+ dstAddr: ipv6Addr.Address,
+ expectRx: true,
+ },
+ {
+ name: "IPv6 bind to wildcard and send to other subnet-local address",
+ addAddress: ipv6ProtocolAddress,
+ dstAddr: otherIPv6Address,
+ expectRx: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ wq := waiter.Queue{}
+ rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
+ }
+ defer rep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
+ if err := rep.Bind(bindAddr); err != nil {
+ t.Fatalf("rep.Bind(%+v): %s", bindAddr, err)
+ }
+
+ sep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
+ }
+ defer sep.Close()
+
+ wopts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{
+ Addr: test.dstAddr,
+ Port: localPort,
+ },
+ }
+ n, _, err := sep.Write(tcpip.SlicePayload(data), wopts)
+ if err != nil {
+ t.Fatalf("sep.Write(_, _): %s", err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want)
+ }
+
+ if gotPayload, _, err := rep.Read(nil); test.expectRx {
+ if err != nil {
+ t.Fatalf("reep.Read(nil): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ }
+ })
+ }
+}
+
+// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address
+// in a loopback interface's associated subnet is bound to the permanently bound
+// address.
+func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
+ const nicID = 1
+
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ }
+ addrBytes := []byte(ipv4Addr.Address)
+ addrBytes[len(addrBytes)-1]++
+ otherAddr := tcpip.Address(addrBytes)
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ r, err := s.FindRoute(nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ params := stack.NetworkHeaderParams{
+ Protocol: 111,
+ TTL: 64,
+ TOS: stack.DefaultTOS,
+ }
+ data := buffer.View([]byte{1, 2, 3, 4})
+ if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: data.ToVectorisedView(),
+ })); err != nil {
+ t.Fatalf("r.WritePacket(nil, %#v, _): %s", params, err)
+ }
+
+ // Removing the address should make the endpoint invalid.
+ if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err)
+ }
+ if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: data.ToVectorisedView(),
+ })); err != tcpip.ErrInvalidEndpointState {
+ t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState)
+ }
+}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
new file mode 100644
index 000000000..4f2ca7f54
--- /dev/null
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -0,0 +1,556 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ defaultMTU = 1280
+ ttl = 255
+)
+
+var (
+ ipv4Addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
+ PrefixLen: 24,
+ }
+ ipv4Subnet = ipv4Addr.Subnet()
+ ipv4SubnetBcast = ipv4Subnet.Broadcast()
+
+ ipv6Addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("200a::1").To16()),
+ PrefixLen: 64,
+ }
+ ipv6Subnet = ipv6Addr.Subnet()
+ ipv6SubnetBcast = ipv6Subnet.Broadcast()
+
+ // Remote addrs.
+ remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4())
+ remoteIPv6Addr = tcpip.Address(net.ParseIP("200b::1").To16())
+)
+
+// TestPingMulticastBroadcast tests that responding to an Echo Request destined
+// to a multicast or broadcast address uses a unicast source address for the
+// reply.
+func TestPingMulticastBroadcast(t *testing.T) {
+ const nicID = 1
+
+ rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ttl,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ tests := []struct {
+ name string
+ dstAddr tcpip.Address
+ }{
+ {
+ name: "IPv4 unicast",
+ dstAddr: ipv4Addr.Address,
+ },
+ {
+ name: "IPv4 directed broadcast",
+ dstAddr: ipv4SubnetBcast,
+ },
+ {
+ name: "IPv4 broadcast",
+ dstAddr: header.IPv4Broadcast,
+ },
+ {
+ name: "IPv4 all-systems multicast",
+ dstAddr: header.IPv4AllSystems,
+ },
+ {
+ name: "IPv6 unicast",
+ dstAddr: ipv6Addr.Address,
+ },
+ {
+ name: "IPv6 all-nodes multicast",
+ dstAddr: header.IPv6AllNodesMulticastAddress,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ })
+ // We only expect a single packet in response to our ICMP Echo Request.
+ e := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err)
+ }
+ ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr}
+ if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err)
+ }
+
+ // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
+ // node when attempting to send the ICMP Echo Reply.
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ var rxICMP func(*channel.Endpoint, tcpip.Address)
+ var expectedSrc tcpip.Address
+ var expectedDst tcpip.Address
+ var protoNum tcpip.NetworkProtocolNumber
+ switch l := len(test.dstAddr); l {
+ case header.IPv4AddressSize:
+ rxICMP = rxIPv4ICMP
+ expectedSrc = ipv4Addr.Address
+ expectedDst = remoteIPv4Addr
+ protoNum = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ rxICMP = rxIPv6ICMP
+ expectedSrc = ipv6Addr.Address
+ expectedDst = remoteIPv6Addr
+ protoNum = header.IPv6ProtocolNumber
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ rxICMP(e, test.dstAddr)
+ pkt, ok := e.Read()
+ if !ok {
+ t.Fatal("expected ICMP response")
+ }
+
+ if pkt.Route.LocalAddress != expectedSrc {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc)
+ }
+ if pkt.Route.RemoteAddress != expectedDst {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst)
+ }
+
+ src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+ if src != expectedSrc {
+ t.Errorf("got pkt source = %s, want = %s", src, expectedSrc)
+ }
+ if dst != expectedDst {
+ t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst)
+ }
+ })
+ }
+
+}
+
+// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some
+// multicast or broadcast address.
+func TestIncomingMulticastAndBroadcast(t *testing.T) {
+ const (
+ nicID = 1
+ remotePort = 5555
+ localPort = 80
+ )
+
+ data := []byte{1, 2, 3, 4}
+
+ rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) {
+ payloadLen := header.UDPMinimumSize + len(data)
+ totalLen := header.IPv4MinimumSize + payloadLen
+ hdr := buffer.NewPrependable(totalLen)
+ u := header.UDP(hdr.Prepend(payloadLen))
+ u.Encode(&header.UDPFields{
+ SrcPort: remotePort,
+ DstPort: localPort,
+ Length: uint16(payloadLen),
+ })
+ copy(u.Payload(), data)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen))
+ sum = header.Checksum(data, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(udp.ProtocolNumber),
+ TTL: ttl,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) {
+ payloadLen := header.UDPMinimumSize + len(data)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
+ u := header.UDP(hdr.Prepend(payloadLen))
+ u.Encode(&header.UDPFields{
+ SrcPort: remotePort,
+ DstPort: localPort,
+ Length: uint16(payloadLen),
+ })
+ copy(u.Payload(), data)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen))
+ sum = header.Checksum(data, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLen),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ tests := []struct {
+ name string
+ bindAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectRx bool
+ }{
+ {
+ name: "IPv4 unicast binding to unicast",
+ bindAddr: ipv4Addr.Address,
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 unicast binding to broadcast",
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: ipv4Addr.Address,
+ expectRx: false,
+ },
+ {
+ name: "IPv4 unicast binding to wildcard",
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
+ },
+
+ {
+ name: "IPv4 directed broadcast binding to subnet broadcast",
+ bindAddr: ipv4SubnetBcast,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 directed broadcast binding to broadcast",
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: false,
+ },
+ {
+ name: "IPv4 directed broadcast binding to wildcard",
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
+ },
+
+ {
+ name: "IPv4 broadcast binding to broadcast",
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: header.IPv4Broadcast,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 broadcast binding to subnet broadcast",
+ bindAddr: ipv4SubnetBcast,
+ dstAddr: header.IPv4Broadcast,
+ expectRx: false,
+ },
+ {
+ name: "IPv4 broadcast binding to wildcard",
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
+ },
+
+ {
+ name: "IPv4 all-systems multicast binding to all-systems multicast",
+ bindAddr: header.IPv4AllSystems,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 all-systems multicast binding to wildcard",
+ dstAddr: header.IPv4AllSystems,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 all-systems multicast binding to unicast",
+ bindAddr: ipv4Addr.Address,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: false,
+ },
+
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 unicast binding to wildcard",
+ dstAddr: ipv6Addr.Address,
+ expectRx: true,
+ },
+ {
+ name: "IPv6 broadcast-like address binding to wildcard",
+ dstAddr: ipv6SubnetBcast,
+ expectRx: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err)
+ }
+ ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr}
+ if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err)
+ }
+
+ var netproto tcpip.NetworkProtocolNumber
+ var rxUDP func(*channel.Endpoint, tcpip.Address)
+ switch l := len(test.dstAddr); l {
+ case header.IPv4AddressSize:
+ netproto = header.IPv4ProtocolNumber
+ rxUDP = rxIPv4UDP
+ case header.IPv6AddressSize:
+ netproto = header.IPv6ProtocolNumber
+ rxUDP = rxIPv6UDP
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%+v): %s", bindAddr, err)
+ }
+
+ rxUDP(e, test.dstAddr)
+ if gotPayload, _, err := ep.Read(nil); test.expectRx {
+ if err != nil {
+ t.Fatalf("Read(nil): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ }
+ })
+ }
+}
+
+// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all
+// interested endpoints.
+func TestReuseAddrAndBroadcast(t *testing.T) {
+ const (
+ nicID = 1
+ localPort = 9000
+ loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff")
+ )
+
+ data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+
+ tests := []struct {
+ name string
+ broadcastAddr tcpip.Address
+ }{
+ {
+ name: "Subnet directed broadcast",
+ broadcastAddr: loopbackBroadcast,
+ },
+ {
+ name: "IPv4 broadcast",
+ broadcastAddr: header.IPv4Broadcast,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x7f\x00\x00\x01",
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ // We use the empty subnet instead of just the loopback subnet so we
+ // also have a route to the IPv4 Broadcast address.
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // We create endpoints that bind to both the wildcard address and the
+ // broadcast address to make sure both of these types of "broadcast
+ // interested" endpoints receive broadcast packets.
+ wq := waiter.Queue{}
+ var eps []tcpip.Endpoint
+ for _, bindWildcard := range []bool{false, true} {
+ // Create multiple endpoints for each type of "broadcast interested"
+ // endpoint so we can test that all endpoints receive the broadcast
+ // packet.
+ for i := 0; i < 2; i++ {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err)
+ }
+
+ bindAddr := tcpip.FullAddress{Port: localPort}
+ if bindWildcard {
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ }
+ } else {
+ bindAddr.Addr = test.broadcastAddr
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ }
+ }
+
+ eps = append(eps, ep)
+ }
+ }
+
+ for i, wep := range eps {
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{
+ Addr: test.broadcastAddr,
+ Port: localPort,
+ },
+ }
+ if n, _, err := wep.Write(data, writeOpts); err != nil {
+ t.Fatalf("eps[%d].Write(_, _): %s", i, err)
+ } else if want := int64(len(data)); n != want {
+ t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want)
+ }
+
+ for j, rep := range eps {
+ if gotPayload, _, err := rep.Read(nil); err != nil {
+ t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err)
+ } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index 7f172f978..606363567 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -13,14 +13,14 @@
// limitations under the License.
// +build go1.9
-// +build !go1.16
+// +build !go1.17
// Check go:linkname function signatures when updating Go version.
package tcpip
import (
- _ "time" // Used with go:linkname.
+ "time" // Used with go:linkname.
_ "unsafe" // Required for go:linkname.
)
@@ -45,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 {
_, _, mono := now()
return mono
}
+
+// AfterFunc implements Clock.AfterFunc.
+func (*StdClock) AfterFunc(d time.Duration, f func()) Timer {
+ return &stdTimer{
+ t: time.AfterFunc(d, f),
+ }
+}
+
+type stdTimer struct {
+ t *time.Timer
+}
+
+var _ Timer = (*stdTimer)(nil)
+
+// Stop implements Timer.Stop.
+func (st *stdTimer) Stop() bool {
+ return st.t.Stop()
+}
+
+// Reset implements Timer.Reset.
+func (st *stdTimer) Reset(d time.Duration) {
+ st.t.Reset(d)
+}
+
+// NewStdTimer returns a Timer implemented with the time package.
+func NewStdTimer(t *time.Timer) Timer {
+ return &stdTimer{t: t}
+}
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
index 59f3b391f..f1dd7c310 100644
--- a/pkg/tcpip/timer.go
+++ b/pkg/tcpip/timer.go
@@ -15,54 +15,54 @@
package tcpip
import (
- "sync"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
-// cancellableTimerInstance is a specific instance of CancellableTimer.
+// jobInstance is a specific instance of Job.
//
-// Different instances are created each time CancellableTimer is Reset so each
-// timer has its own earlyReturn signal. This is to address a bug when a
-// CancellableTimer is stopped and reset in quick succession resulting in a
-// timer instance's earlyReturn signal being affected or seen by another timer
-// instance.
+// Different instances are created each time Job is scheduled so each timer has
+// its own earlyReturn signal. This is to address a bug when a Job is stopped
+// and reset in quick succession resulting in a timer instance's earlyReturn
+// signal being affected or seen by another timer instance.
//
// Consider the following sceneario where timer instances share a common
// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a
// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second
// (B), third (C), and fourth (D) instance of the timer firing, respectively):
// T1: Obtain L
-// T1: Create a new CancellableTimer w/ lock L (create instance A)
+// T1: Create a new Job w/ lock L (create instance A)
// T2: instance A fires, blocked trying to obtain L.
// T1: Attempt to stop instance A (set earlyReturn = true)
-// T1: Reset timer (create instance B)
+// T1: Schedule timer (create instance B)
// T3: instance B fires, blocked trying to obtain L.
// T1: Attempt to stop instance B (set earlyReturn = true)
-// T1: Reset timer (create instance C)
+// T1: Schedule timer (create instance C)
// T4: instance C fires, blocked trying to obtain L.
// T1: Attempt to stop instance C (set earlyReturn = true)
-// T1: Reset timer (create instance D)
+// T1: Schedule timer (create instance D)
// T5: instance D fires, blocked trying to obtain L.
// T1: Release L
//
-// Now that T1 has released L, any of the 4 timer instances can take L and check
-// earlyReturn. If the timers simply check earlyReturn and then do nothing
-// further, then instance D will never early return even though it was not
-// requested to stop. If the timers reset earlyReturn before early returning,
-// then all but one of the timers will do work when only one was expected to.
-// If CancellableTimer resets earlyReturn when resetting, then all the timers
+// Now that T1 has released L, any of the 4 timer instances can take L and
+// check earlyReturn. If the timers simply check earlyReturn and then do
+// nothing further, then instance D will never early return even though it was
+// not requested to stop. If the timers reset earlyReturn before early
+// returning, then all but one of the timers will do work when only one was
+// expected to. If Job resets earlyReturn when resetting, then all the timers
// will fire (again, when only one was expected to).
//
// To address the above concerns the simplest solution was to give each timer
// its own earlyReturn signal.
-type cancellableTimerInstance struct {
- timer *time.Timer
+type jobInstance struct {
+ timer Timer
// Used to inform the timer to early return when it gets stopped while the
// lock the timer tries to obtain when fired is held (T1 is a goroutine that
// tries to cancel the timer and T2 is the goroutine that handles the timer
// firing):
- // T1: Obtain the lock, then call StopLocked()
+ // T1: Obtain the lock, then call Cancel()
// T2: timer fires, and gets blocked on obtaining the lock
// T1: Releases lock
// T2: Obtains lock does unintended work
@@ -73,27 +73,33 @@ type cancellableTimerInstance struct {
earlyReturn *bool
}
-// stop stops the timer instance t from firing if it hasn't fired already. If it
+// stop stops the job instance j from firing if it hasn't fired already. If it
// has fired and is blocked at obtaining the lock, earlyReturn will be set to
// true so that it will early return when it obtains the lock.
-func (t *cancellableTimerInstance) stop() {
- if t.timer != nil {
- t.timer.Stop()
- *t.earlyReturn = true
+func (j *jobInstance) stop() {
+ if j.timer != nil {
+ j.timer.Stop()
+ *j.earlyReturn = true
}
}
-// CancellableTimer is a timer that does some work and can be safely cancelled
-// when it fires at the same time some "related work" is being done.
+// Job represents some work that can be scheduled for execution. The work can
+// be safely cancelled when it fires at the same time some "related work" is
+// being done.
//
// The term "related work" is defined as some work that needs to be done while
// holding some lock that the timer must also hold while doing some work.
//
-// Note, it is not safe to copy a CancellableTimer as its timer instance creates
-// a closure over the address of the CancellableTimer.
-type CancellableTimer struct {
+// Note, it is not safe to copy a Job as its timer instance creates
+// a closure over the address of the Job.
+type Job struct {
+ _ sync.NoCopy
+
+ // The clock used to schedule the backing timer
+ clock Clock
+
// The active instance of a cancellable timer.
- instance cancellableTimerInstance
+ instance jobInstance
// locker is the lock taken by the timer immediately after it fires and must
// be held when attempting to stop the timer.
@@ -110,75 +116,91 @@ type CancellableTimer struct {
fn func()
}
-// StopLocked prevents the Timer from firing if it has not fired already.
+// Cancel prevents the Job from executing if it has not executed already.
//
-// If the timer is blocked on obtaining the t.locker lock when StopLocked is
-// called, it will early return instead of calling t.fn.
+// Cancel requires appropriate locking to be in place for any resources managed
+// by the Job. If the Job is blocked on obtaining the lock when Cancel is
+// called, it will early return.
//
// Note, t will be modified.
//
-// t.locker MUST be locked.
-func (t *CancellableTimer) StopLocked() {
- t.instance.stop()
+// j.locker MUST be locked.
+func (j *Job) Cancel() {
+ j.instance.stop()
// Nothing to do with the stopped instance anymore.
- t.instance = cancellableTimerInstance{}
+ j.instance = jobInstance{}
}
-// Reset changes the timer to expire after duration d.
+// Schedule schedules the Job for execution after duration d. This can be
+// called on cancelled or completed Jobs to schedule them again.
//
-// Note, t will be modified.
+// Schedule should be invoked only on unscheduled, cancelled, or completed
+// Jobs. To be safe, callers should always call Cancel before calling Schedule.
//
-// Reset should only be called on stopped or expired timers. To be safe, callers
-// should always call StopLocked before calling Reset.
-func (t *CancellableTimer) Reset(d time.Duration) {
+// Note, j will be modified.
+func (j *Job) Schedule(d time.Duration) {
// Create a new instance.
earlyReturn := false
// Capture the locker so that updating the timer does not cause a data race
// when a timer fires and tries to obtain the lock (read the timer's locker).
- locker := t.locker
- t.instance = cancellableTimerInstance{
- timer: time.AfterFunc(d, func() {
+ locker := j.locker
+ j.instance = jobInstance{
+ timer: j.clock.AfterFunc(d, func() {
locker.Lock()
defer locker.Unlock()
if earlyReturn {
// If we reach this point, it means that the timer fired while another
- // goroutine called StopLocked while it had the lock. Simply return
- // here and do nothing further.
+ // goroutine called Cancel while it had the lock. Simply return here
+ // and do nothing further.
earlyReturn = false
return
}
- t.fn()
+ j.fn()
}),
earlyReturn: &earlyReturn,
}
}
-// Lock is a no-op used by the copylocks checker from go vet.
-//
-// See CancellableTimer for details about why it shouldn't be copied.
-//
-// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more
-// details about the copylocks checker.
-func (*CancellableTimer) Lock() {}
-
-// Unlock is a no-op used by the copylocks checker from go vet.
-//
-// See CancellableTimer for details about why it shouldn't be copied.
-//
-// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more
-// details about the copylocks checker.
-func (*CancellableTimer) Unlock() {}
-
-// NewCancellableTimer returns an unscheduled CancellableTimer with the given
-// locker and fn.
-//
-// fn MUST NOT attempt to lock locker.
-//
-// Callers must call Reset to schedule the timer to fire.
-func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer {
- return &CancellableTimer{locker: locker, fn: fn}
+// NewJob returns a new Job that can be used to schedule f to run in its own
+// gorountine. l will be locked before calling f then unlocked after f returns.
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// message = "bar"
+// mu.Unlock()
+//
+// // Output: bar
+//
+// f MUST NOT attempt to lock l.
+//
+// l MUST be locked prior to calling the returned job's Cancel().
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// job.Cancel()
+// mu.Unlock()
+func NewJob(c Clock, l sync.Locker, f func()) *Job {
+ return &Job{
+ clock: c,
+ locker: l,
+ fn: f,
+ }
}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
index b4940e397..a82384c49 100644
--- a/pkg/tcpip/timer_test.go
+++ b/pkg/tcpip/timer_test.go
@@ -28,8 +28,8 @@ const (
longDuration = 1 * time.Second
)
-func TestCancellableTimerReassignment(t *testing.T) {
- var timer tcpip.CancellableTimer
+func TestJobReschedule(t *testing.T) {
+ var clock tcpip.StdClock
var wg sync.WaitGroup
var lock sync.Mutex
@@ -43,26 +43,27 @@ func TestCancellableTimerReassignment(t *testing.T) {
// that has an active timer (even if it has been stopped as a stopped
// timer may be blocked on a lock before it can check if it has been
// stopped while another goroutine holds the same lock).
- timer = *tcpip.NewCancellableTimer(&lock, func() {
+ job := tcpip.NewJob(&clock, &lock, func() {
wg.Done()
})
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
lock.Unlock()
}()
}
wg.Wait()
}
-func TestCancellableTimerFire(t *testing.T) {
+func TestJobExecution(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
- timer := tcpip.NewCancellableTimer(&lock, func() {
+ job := tcpip.NewJob(&clock, &lock, func() {
ch <- struct{}{}
})
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -82,17 +83,18 @@ func TestCancellableTimerFire(t *testing.T) {
func TestCancellableTimerResetFromLongDuration(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(middleDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(middleDuration)
lock.Lock()
- timer.StopLocked()
+ job.Cancel()
lock.Unlock()
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -109,16 +111,17 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
}
}
-func TestCancellableTimerResetFromShortDuration(t *testing.T) {
+func TestJobRescheduleFromShortDuration(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
// Wait for timer to fire if it wasn't correctly stopped.
@@ -128,7 +131,7 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) {
case <-time.After(middleDuration):
}
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -145,17 +148,18 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) {
}
}
-func TestCancellableTimerImmediatelyStop(t *testing.T) {
+func TestJobImmediatelyCancel(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
for i := 0; i < 1000; i++ {
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
}
@@ -167,25 +171,26 @@ func TestCancellableTimerImmediatelyStop(t *testing.T) {
}
}
-func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
+func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
for i := 0; i < 10; i++ {
- timer.Reset(middleDuration)
+ job.Schedule(middleDuration)
lock.Lock()
// Sleep until the timer fires and gets blocked trying to take the lock.
time.Sleep(middleDuration * 2)
- timer.StopLocked()
+ job.Cancel()
lock.Unlock()
}
@@ -201,17 +206,18 @@ func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
// Sleep until the timer fires and gets blocked trying to take the lock.
time.Sleep(middleDuration)
- timer.StopLocked()
- timer.Reset(shortDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
}
lock.Unlock()
@@ -230,18 +236,19 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
}
}
-func TestManyCancellableTimerResetUnderLock(t *testing.T) {
+func TestManyJobReschedulesUnderLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
- timer.StopLocked()
- timer.Reset(shortDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
}
lock.Unlock()
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 8ce294002..41eb0ca44 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -74,6 +74,8 @@ type endpoint struct {
route stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -343,7 +345,16 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
}
// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ e.linger = *v
+ e.mu.Unlock()
+ }
return nil
}
@@ -411,9 +422,12 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ *o = e.linger
+ e.mu.Unlock()
return nil
default:
@@ -426,9 +440,13 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
return tcpip.ErrInvalidEndpointState
}
- hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()),
+ })
+ pkt.Owner = owner
- icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
copy(icmpv4, data)
// Set the ident to the user-specified port. Sequence number should
// already be set by the user.
@@ -443,15 +461,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
+ pkt.Data = data.ToVectorisedView()
+
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
- Header: hdr,
- Data: data.ToVectorisedView(),
- TransportHeader: buffer.View(icmpv4),
- Owner: owner,
- })
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
}
func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
@@ -459,9 +474,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
return tcpip.ErrInvalidEndpointState
}
- hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength()))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()),
+ })
- icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
copy(icmpv6, data)
// Set the ident. Sequence number is provided by the user.
icmpv6.SetIdent(ident)
@@ -473,15 +491,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
dataVV := data.ToVectorisedView()
icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV))
+ pkt.Data = dataVV
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
- Header: hdr,
- Data: dataVV,
- TransportHeader: buffer.View(icmpv6),
- })
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
}
// checkV4MappedLocked determines the effective network protocol and converts
@@ -600,7 +615,7 @@ func (*endpoint) Listen(int) *tcpip.Error {
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
@@ -744,15 +759,19 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
- if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply {
+ h := header.ICMPv4(pkt.TransportHeader().View())
+ // TODO(b/129292233): Determine if len(h) check is still needed after early
+ // parsing.
+ if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
- if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply {
+ h := header.ICMPv6(pkt.TransportHeader().View())
+ // TODO(b/129292233): Determine if len(h) check is still needed after early
+ // parsing.
+ if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
@@ -786,12 +805,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
},
}
- packet.data = pkt.Data
+ // ICMP socket's data includes ICMP header.
+ packet.data = pkt.TransportHeader().View().ToVectorisedView()
+ packet.data.Append(pkt.Data)
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
@@ -827,3 +848,8 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements stack.TransportEndpoint.Wait.
func (*endpoint) Wait() {}
+
+// LastError implements tcpip.Endpoint.LastError.
+func (*endpoint) LastError() *tcpip.Error {
+ return nil
+}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 74ef6541e..87d510f96 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -13,12 +13,7 @@
// limitations under the License.
// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
-// protocols for use in ping. To use it in the networking stack, this package
-// must be added to the project, and activated on the stack by passing
-// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport
-// protocols when calling stack.New(). Then endpoints can be created by passing
-// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
-// when calling Stack.NewEndpoint().
+// protocols for use in ping.
package icmp
import (
@@ -42,6 +37,8 @@ const (
// protocol implements stack.TransportProtocol.
type protocol struct {
+ stack *stack.Stack
+
number tcpip.TransportProtocolNumber
}
@@ -62,20 +59,20 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
// NewEndpoint creates a new icmp endpoint. It implements
// stack.TransportProtocol.NewEndpoint.
-func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return newEndpoint(stack, netProto, p.number, waiterQueue)
+ return newEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// NewRawEndpoint creates a new raw icmp endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return raw.NewEndpoint(stack, netProto, p.number, waiterQueue)
+ return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// MinimumPacketSize returns the minimum valid icmp packet size.
@@ -104,17 +101,17 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
- return true
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ return stack.UnknownDestinationPacketHandled
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (*protocol) SetOption(option interface{}) *tcpip.Error {
+func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements stack.TransportProtocol.Option.
-func (*protocol) Option(option interface{}) *tcpip.Error {
+func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
@@ -135,11 +132,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
}
// NewProtocol4 returns an ICMPv4 transport protocol.
-func NewProtocol4() stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
+func NewProtocol4(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s, number: ProtocolNumber4}
}
// NewProtocol6 returns an ICMPv6 transport protocol.
-func NewProtocol6() stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
+func NewProtocol6(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s, number: ProtocolNumber6}
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index a8f8454dd..072601d2d 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -45,6 +45,9 @@ type packet struct {
timestampNS int64
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ // packetInfo holds additional information like the protocol
+ // of the packet etc.
+ packetInfo tcpip.LinkPacketInfo
}
// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
@@ -79,6 +82,13 @@ type endpoint struct {
closed bool
stats tcpip.TransportEndpointStats `state:"nosave"`
bound bool
+ boundNIC tcpip.NICID
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
+
+ // lastErrorMu protects lastError.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
}
// NewEndpoint returns a new packet endpoint.
@@ -146,8 +156,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.PacketEndpoint.ReadPacket.
+func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -172,16 +182,25 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
*addr = packet.senderAddr
}
+ if info != nil {
+ *info = packet.packetInfo
+ }
+
return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
}
-func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
- // TODO(b/129292371): Implement.
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ return ep.ReadPacket(addr, nil)
+}
+
+func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/173): Implement.
return 0, nil, tcpip.ErrInvalidOptionValue
}
// Peek implements tcpip.Endpoint.Peek.
-func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -193,25 +212,25 @@ func (*endpoint) Disconnect() *tcpip.Error {
// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
// connected, and this function always returnes tcpip.ErrNotSupported.
-func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
// with Shutdown, and this function always returns tcpip.ErrNotSupported.
-func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
// Listen, and this function always returns tcpip.ErrNotSupported.
-func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+func (*endpoint) Listen(backlog int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with
// Accept, and this function always returns tcpip.ErrNotSupported.
-func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
@@ -229,12 +248,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound {
- return tcpip.ErrAlreadyBound
+ if ep.bound && ep.boundNIC == addr.NIC {
+ // If the NIC being bound is the same then just return success.
+ return nil
}
// Unregister endpoint with all the nics.
ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.bound = false
// Bind endpoint to receive packets from specific interface.
if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
@@ -242,17 +263,18 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
ep.bound = true
+ ep.boundNIC = addr.NIC
return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotSupported
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// Even a connected socket doesn't return a remote address.
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -277,8 +299,20 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
-func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+
+ case *tcpip.LingerOption:
+ ep.mu.Lock()
+ ep.linger = *v
+ ep.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
@@ -330,13 +364,31 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
}
+func (ep *endpoint) LastError() *tcpip.Error {
+ ep.lastErrorMu.Lock()
+ defer ep.lastErrorMu.Unlock()
+
+ err := ep.lastError
+ ep.lastError = nil
+ return err
+}
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ ep.mu.Lock()
+ *o = ep.linger
+ ep.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrNotSupported
+ }
}
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return false, tcpip.ErrNotSupported
}
@@ -393,48 +445,73 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
// Push new packet into receive list and increment the buffer size.
var packet packet
- // TODO(b/129292371): Return network protocol.
- if len(pkt.LinkHeader) > 0 {
+ // TODO(gvisor.dev/issue/173): Return network protocol.
+ if !pkt.LinkHeader().View().IsEmpty() {
// Get info directly from the ethernet header.
- hdr := header.Ethernet(pkt.LinkHeader)
+ hdr := header.Ethernet(pkt.LinkHeader().View())
packet.senderAddr = tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.Address(hdr.SourceAddress()),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.Address(localAddr),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
}
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = pkt.Data
+ switch pkt.PktType {
+ case tcpip.PacketHost:
+ packet.data = pkt.Data
+ case tcpip.PacketOutgoing:
+ // Strip Link Header.
+ var combinedVV buffer.VectorisedView
+ if v := pkt.NetworkHeader().View(); !v.IsEmpty() {
+ combinedVV.AppendView(v)
+ }
+ if v := pkt.TransportHeader().View(); !v.IsEmpty() {
+ combinedVV.AppendView(v)
+ }
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ default:
+ panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ }
+
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
var linkHeader buffer.View
- if len(pkt.LinkHeader) == 0 {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
+ if pkt.PktType != tcpip.PacketOutgoing {
+ if pkt.LinkHeader().View().IsEmpty() {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...)
}
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- linkHeader = buffer.View(fakeHeader)
+ combinedVV := linkHeader.ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
} else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
}
- combinedVV := linkHeader.ToVectorisedView()
- combinedVV.Append(pkt.Data)
- packet.data = combinedVV
}
- packet.timestampNS = ep.stack.NowNanoseconds()
+ packet.timestampNS = ep.stack.Clock().NowNanoseconds()
ep.rcvList.PushBack(&packet)
ep.rcvBufSize += packet.data.Size()
@@ -448,7 +525,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
}
// State implements socket.Socket.State.
-func (ep *endpoint) State() uint32 {
+func (*endpoint) State() uint32 {
return 0
}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 9b88f17e4..e2fa96d17 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() {
panic(*err)
}
}
+
+// saveLastError is invoked by stateify.
+func (ep *endpoint) saveLastError() string {
+ if ep.lastError == nil {
+ return ""
+ }
+
+ return ep.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (ep *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ ep.lastError = tcpip.StringToError(s)
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 766c7648e..e37c00523 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -63,6 +63,7 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
+ hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
@@ -83,6 +84,8 @@ type endpoint struct {
// Connect(), and is valid only when conneted is true.
route stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -108,6 +111,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
associated: associated,
+ hdrIncluded: !associated,
}
// Override with stack defaults.
@@ -182,10 +186,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
// Read implements tcpip.Endpoint.Read.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if !e.associated {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
- }
-
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -263,7 +263,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if !e.associated {
+ if e.hdrIncluded {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
e.mu.RUnlock()
@@ -353,19 +353,24 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- if !e.associated {
- if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{
+ if e.hdrIncluded {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(payloadBytes).ToVectorisedView(),
- }); err != nil {
+ })
+ if err := route.WriteHeaderIncludedPacket(pkt); err != nil {
return 0, nil, err
}
} else {
- hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
- if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
- Header: hdr,
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- Owner: e.owner,
- }); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(route.MaxHeaderLength()),
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ })
+ pkt.Owner = e.owner
+ if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: e.TransProto,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, pkt); err != nil {
return 0, nil, err
}
}
@@ -443,12 +448,12 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Listen implements tcpip.Endpoint.Listen.
-func (e *endpoint) Listen(backlog int) *tcpip.Error {
+func (*endpoint) Listen(backlog int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept implements tcpip.Endpoint.Accept.
-func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
@@ -458,7 +463,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
defer e.mu.Unlock()
// If a local address was specified, verify that it's valid.
- if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -479,12 +484,12 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotSupported
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// Even a connected socket doesn't return a remote address.
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -507,12 +512,31 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
}
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ e.linger = *v
+ e.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ e.hdrIncluded = v
+ e.mu.Unlock()
+ return nil
+ }
return tcpip.ErrUnknownProtocolOption
}
@@ -561,9 +585,12 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ *o = e.linger
+ e.mu.Unlock()
return nil
default:
@@ -577,6 +604,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.KeepaliveEnabledOption:
return false, nil
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ v := e.hdrIncluded
+ e.mu.Unlock()
+ return v, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -616,8 +649,15 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
e.rcvMu.Lock()
- // Drop the packet if our buffer is currently full.
- if e.rcvClosed {
+ // Drop the packet if our buffer is currently full or if this is an unassociated
+ // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
+ // See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
@@ -667,16 +707,17 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
// slice. Save/restore doesn't support overlapping slices and will fail.
var combinedVV buffer.VectorisedView
if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
- headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader))
- headers = append(headers, pkt.NetworkHeader...)
- headers = append(headers, pkt.TransportHeader...)
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ headers := make(buffer.View, 0, len(network)+len(transport))
+ headers = append(headers, network...)
+ headers = append(headers, transport...)
combinedVV = headers.ToVectorisedView()
} else {
- combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView()
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
}
combinedVV.Append(pkt.Data)
packet.data = combinedVV
- packet.timestampNS = e.stack.NowNanoseconds()
+ packet.timestampNS = e.stack.Clock().NowNanoseconds()
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
@@ -709,3 +750,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements stack.TransportEndpoint.Wait.
func (*endpoint) Wait() {}
+
+func (*endpoint) LastError() *tcpip.Error {
+ return nil
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 18ff89ffc..518449602 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -40,6 +40,8 @@ go_library(
"endpoint_state.go",
"forwarder.go",
"protocol.go",
+ "rack.go",
+ "rack_state.go",
"rcv.go",
"rcv_state.go",
"reno.go",
@@ -49,6 +51,7 @@ go_library(
"segment_heap.go",
"segment_queue.go",
"segment_state.go",
+ "segment_unsafe.go",
"snd.go",
"snd_state.go",
"tcp_endpoint_list.go",
@@ -66,6 +69,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
@@ -82,6 +86,7 @@ go_test(
"dual_stack_test.go",
"sack_scoreboard_test.go",
"tcp_noracedetector_test.go",
+ "tcp_rack_test.go",
"tcp_sack_test.go",
"tcp_test.go",
"tcp_timestamp_test.go",
@@ -89,6 +94,7 @@ go_test(
shard_count = 10,
deps = [
":tcp",
+ "//pkg/rand",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 6e00e5526..b706438bd 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -212,7 +212,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd)
- n.amss = mssForRoute(&n.route)
+ n.amss = calculateAdvertisedMSS(n.userMSS, n.route)
n.setEndpointState(StateConnecting)
n.maybeEnableTimestamp(rcvdSynOpts)
@@ -380,6 +380,7 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.portFlags = e.portFlags
n.boundBindToDevice = e.boundBindToDevice
n.boundPortFlags = e.boundPortFlags
+ n.userMSS = e.userMSS
}
// reserveTupleLocked reserves an accepted endpoint's tuple.
@@ -481,9 +482,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
- // TODO(b/143300739): Use the userMSS of the listening socket
- // for accepted sockets.
-
switch {
case s.flags == header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
@@ -514,16 +512,19 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
// Send SYN without window scaling because we currently
- // dont't encode this information in the cookie.
+ // don't encode this information in the cookie.
//
// Enable Timestamp option if the original syn did have
// the timestamp option specified.
+ //
+ // Use the user supplied MSS on the listening socket for
+ // new connections, if available.
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(timeStampOffset()),
+ TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
TSEcr: opts.TSVal,
- MSS: mssForRoute(&s.route),
+ MSS: calculateAdvertisedMSS(e.userMSS, s.route),
}
e.sendSynTCP(&s.route, tcpFields{
id: s.id,
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 81b740115..189c01c8f 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.LastError()
+ }
}
// Wait for notification.
@@ -519,7 +522,7 @@ func (h *handshake) execute() *tcpip.Error {
s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
defer s.Done()
- var sackEnabled SACKEnabled
+ var sackEnabled tcpip.TCPSACKEnabled
if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
// If stack returned an error when checking for SACKEnabled
// status then just default to switching off SACK negotiation.
@@ -616,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.LastError()
+ }
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
@@ -740,11 +746,8 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV
func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) {
optLen := len(tf.opts)
- hdr := &pkt.Header
- packetSize := pkt.Data.Size()
- // Initialize the header.
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
- pkt.TransportHeader = buffer.View(tcp)
+ tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen))
+ pkt.TransportProtocolNumber = header.TCPProtocolNumber
tcp.Encode(&header.TCPFields{
SrcPort: tf.id.LocalPort,
DstPort: tf.id.RemotePort,
@@ -756,8 +759,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
})
copy(tcp[header.TCPMinimumSize:], tf.opts)
- length := uint16(hdr.UsedLength() + packetSize)
- xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size()))
// Only calculate the checksum if offloading isn't supported.
if gso != nil && gso.NeedsCsum {
// This is called CHECKSUM_PARTIAL in the Linux kernel. We
@@ -795,17 +797,18 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
packetSize = size
}
size -= packetSize
- var pkt stack.PacketBuffer
- pkt.Header = buffer.NewPrependable(hdrSize)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: hdrSize,
+ })
pkt.Hash = tf.txHash
pkt.Owner = owner
pkt.EgressRoute = r
pkt.GSOOptions = gso
- pkt.NetworkProtocolNumber = r.NetworkProtocolNumber()
+ pkt.NetworkProtocolNumber = r.NetProto
data.ReadToVV(&pkt.Data, packetSize)
- buildTCPHdr(r, tf, &pkt, gso)
+ buildTCPHdr(r, tf, pkt, gso)
tf.seq = tf.seq.Add(seqnum.Size(packetSize))
- pkts.PushBack(&pkt)
+ pkts.PushBack(pkt)
}
if tf.ttl == 0 {
@@ -831,12 +834,12 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac
return sendTCPBatch(r, tf, data, gso, owner)
}
- pkt := &stack.PacketBuffer{
- Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
- Data: data,
- Hash: tf.txHash,
- Owner: owner,
- }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen,
+ Data: data,
+ })
+ pkt.Hash = tf.txHash
+ pkt.Owner = owner
buildTCPHdr(r, tf, pkt, gso)
if tf.ttl == 0 {
@@ -895,7 +898,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// sendRaw sends a TCP segment to the endpoint's peer.
func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
var sackBlocks []header.SACKBlock
- if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
@@ -1000,9 +1003,8 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// (indicated by a negative send window scale).
e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
- rcvBufSize := seqnum.Size(e.receiveBufferSize())
e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
// Bootstrap the auto tuning algorithm. Starting at zero will
// result in a really large receive window after the first auto
// tuning adjustment.
@@ -1018,14 +1020,19 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// delivered to this endpoint from the demuxer when the endpoint
// is transitioned to StateClose.
func (e *endpoint) transitionToStateCloseLocked() {
- if e.EndpointState() == StateClose {
+ s := e.EndpointState()
+ if s == StateClose {
return
}
+
+ if s.connected() {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
+ e.stack.Stats().TCP.EstablishedClosed.Increment()
+ }
+
// Mark the endpoint as fully closed for reads/writes.
e.cleanupLocked()
e.setEndpointState(StateClose)
- e.stack.Stats().TCP.CurrentConnected.Decrement()
- e.stack.Stats().TCP.EstablishedClosed.Increment()
}
// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed
@@ -1128,12 +1135,11 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
}
cont, err := e.handleSegment(s)
+ s.decRef()
if err != nil {
- s.decRef()
return err
}
if !cont {
- s.decRef()
return nil
}
}
@@ -1155,13 +1161,18 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
return nil
}
-// handleSegment handles a given segment and notifies the worker goroutine if
-// if the connection should be terminated.
-func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
- // Invoke the tcp probe if installed.
+func (e *endpoint) probeSegment() {
if e.probe != nil {
e.probe(e.completeState())
}
+}
+
+// handleSegment handles a given segment and notifies the worker goroutine if
+// if the connection should be terminated.
+func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+ // Invoke the tcp probe if installed. The tcp probe function will update
+ // the TCPEndpointState after the segment is processed.
+ defer e.probeSegment()
if s.flagIsSet(header.TCPFlagRst) {
if ok, err := e.handleReset(s); !ok {
@@ -1208,6 +1219,12 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
return true, nil
}
+ // Increase counter if after processing the segment we would potentially
+ // advertise a zero window.
+ if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above {
+ e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
+ }
+
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
@@ -1220,7 +1237,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// or a notification from the protocolMainLoop (caller goroutine).
// This means that with this return, the segment dequeue below can
// never occur on a closed endpoint.
- s.decRef()
return false, nil
}
@@ -1412,10 +1428,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.rcv.nonZeroWindow()
}
- if n&notifyReceiveWindowChanged != 0 {
- e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize())
- }
-
if n&notifyMTUChanged != 0 {
e.sndBufMu.Lock()
count := e.packetTooBigCount
@@ -1690,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
}
case notification:
n := e.fetchNotifications()
- if n&notifyClose != 0 || n&notifyAbort != 0 {
+ if n&notifyAbort != 0 {
return nil
}
if n&notifyDrain != 0 {
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 804e95aea..560b4904c 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -78,16 +78,15 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network
ackCheckers := append(checkers, checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
))
checker.IPv4(t, c.GetPacket(), ackCheckers...)
// Wait for connection to be established.
select {
case <-ch:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.LastError(); err != nil {
t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -186,16 +185,15 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network
ackCheckers := append(checkers, checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
))
checker.IPv6(t, c.GetV6Packet(), ackCheckers...)
// Wait for connection to be established.
select {
case <-ch:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.LastError(); err != nil {
t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -285,7 +283,7 @@ func TestV4RefuseOnV6Only(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
}
@@ -321,7 +319,7 @@ func TestV6RefuseOnBoundToV4Mapped(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
}
@@ -354,7 +352,7 @@ func testV4Accept(t *testing.T, c *context.Context) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
@@ -373,12 +371,12 @@ func testV4Accept(t *testing.T, c *context.Context) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- nep, _, err := c.EP.Accept()
+ nep, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept()
+ nep, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -494,7 +492,7 @@ func TestV6AcceptOnV6(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
@@ -512,13 +510,13 @@ func TestV6AcceptOnV6(t *testing.T) {
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
-
- nep, _, err := c.EP.Accept()
+ var addr tcpip.FullAddress
+ nep, _, err := c.EP.Accept(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept()
+ nep, _, err = c.EP.Accept(&addr)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -528,20 +526,14 @@ func TestV6AcceptOnV6(t *testing.T) {
}
}
+ if addr.Addr != context.TestV6Addr {
+ t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr)
+ }
+
// Make sure we can still query the v6 only status of the new endpoint,
// that is, that it is in fact a v6 socket.
if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
- t.Fatalf("GetSockOpt failed failed: %v", err)
- }
-
- // Check the peer address.
- addr, err := nep.GetRemoteAddress()
- if err != nil {
- t.Fatalf("GetRemoteAddress failed failed: %v", err)
- }
-
- if addr.Addr != context.TestV6Addr {
- t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr)
+ t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err)
}
}
@@ -568,8 +560,9 @@ func TestV4AcceptOnV4(t *testing.T) {
func testV4ListenClose(t *testing.T, c *context.Context) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
const n = uint16(32)
@@ -612,12 +605,12 @@ func testV4ListenClose(t *testing.T, c *context.Context) {
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- nep, _, err := c.EP.Accept()
+ nep, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- nep, _, err = c.EP.Accept()
+ nep, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index bd3ec5a8d..ae817091a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -63,6 +63,17 @@ const (
StateClosing
)
+const (
+ // rcvAdvWndScale is used to split the available socket buffer into
+ // application buffer and the window to be advertised to the peer. This is
+ // currently hard coded to split the available space equally.
+ rcvAdvWndScale = 1
+
+ // SegOverheadFactor is used to multiply the value provided by the
+ // user on a SetSockOpt for setting the socket send/receive buffer sizes.
+ SegOverheadFactor = 2
+)
+
// connected returns true when s is one of the states representing an
// endpoint connected to a peer.
func (s EndpointState) connected() bool {
@@ -149,7 +160,6 @@ func (s EndpointState) String() string {
// Reasons for notifying the protocol goroutine.
const (
notifyNonZeroReceiveWindow = 1 << iota
- notifyReceiveWindowChanged
notifyClose
notifyMTUChanged
notifyDrain
@@ -384,13 +394,26 @@ type endpoint struct {
// to indicate to users that no more data is coming.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
- rcvList segmentList `state:"wait"`
- rcvClosed bool
- rcvBufSize int
+ rcvListMu sync.Mutex `state:"nosave"`
+ rcvList segmentList `state:"wait"`
+ rcvClosed bool
+ // rcvBufSize is the total size of the receive buffer.
+ rcvBufSize int
+ // rcvBufUsed is the actual number of payload bytes held in the receive buffer
+ // not counting any overheads of the segments itself. NOTE: This will always
+ // be strictly <= rcvMemUsed below.
rcvBufUsed int
rcvAutoParams rcvBufAutoTuneParams
+ // rcvMemUsed tracks the total amount of memory in use by received segments
+ // held in rcvList, pendingRcvdSegments and the segment queue. This is used to
+ // compute the window and the actual available buffer space. This is distinct
+ // from rcvBufUsed above which is the actual number of payload bytes held in
+ // the buffer not including any segment overheads.
+ //
+ // rcvMemUsed must be accessed atomically.
+ rcvMemUsed int32
+
// mu protects all endpoint fields unless documented otherwise. mu must
// be acquired before interacting with the endpoint fields.
mu sync.Mutex `state:"nosave"`
@@ -449,10 +472,11 @@ type endpoint struct {
// recentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint.
- //
- // recentTS must be read/written atomically.
recentTS uint32
+ // recentTSTime is the unix time when we updated recentTS last.
+ recentTSTime time.Time `state:".(unixTime)"`
+
// tsOffset is a randomized offset added to the value of the
// TSVal field in the timestamp option.
tsOffset uint32
@@ -653,6 +677,9 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -666,7 +693,8 @@ func (e *endpoint) UniqueID() uint64 {
// r, it will be used; otherwise, the maximum possible MSS will be used.
func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
// The maximum possible MSS is dependent on the route.
- maxMSS := mssForRoute(&r)
+ // TODO(b/143359391): Respect TCP Min and Max size.
+ maxMSS := uint16(r.MTU() - header.TCPMinimumSize)
if userMSS != 0 && userMSS < maxMSS {
return userMSS
@@ -795,15 +823,15 @@ func (e *endpoint) EndpointState() EndpointState {
return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
}
-// setRecentTimestamp atomically sets the recentTS field to the
-// provided value.
+// setRecentTimestamp sets the recentTS field to the provided value.
func (e *endpoint) setRecentTimestamp(recentTS uint32) {
- atomic.StoreUint32(&e.recentTS, recentTS)
+ e.recentTS = recentTS
+ e.recentTSTime = time.Now()
}
-// recentTimestamp atomically reads and returns the value of the recentTS field.
+// recentTimestamp returns the value of the recentTS field.
func (e *endpoint) recentTimestamp() uint32 {
- return atomic.LoadUint32(&e.recentTS)
+ return e.recentTS
}
// keepalive is a synchronization wrapper used to appease stateify. See the
@@ -847,12 +875,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
maxSynRetries: DefaultSynRetries,
}
- var ss SendBufferSizeOption
+ var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
e.sndBufSize = ss.Default
}
- var rs ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
e.rcvBufSize = rs.Default
}
@@ -862,12 +890,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.cc = cs
}
- var mrb tcpip.ModerateReceiveBufferOption
+ var mrb tcpip.TCPModerateReceiveBufferOption
if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
e.rcvAutoParams.disabled = !bool(mrb)
}
- var de DelayEnabled
+ var de tcpip.TCPDelayEnabled
if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
e.SetSockOptBool(tcpip.DelayOption, true)
}
@@ -886,7 +914,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.probe = p
}
- e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.segmentQueue.ep = e
e.tsOffset = timeStampOffset()
e.acceptCond = sync.NewCond(&e.acceptMu)
@@ -899,10 +927,15 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
result := waiter.EventMask(0)
switch e.EndpointState() {
- case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
+ case StateInitial, StateBound:
+ // This prevents blocking of new sockets which are not
+ // connected when SO_LINGER is set.
+ result |= waiter.EventHUp
+
+ case StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
- case StateClose, StateError:
+ case StateClose, StateError, StateTimeWait:
// Ready for anything.
result = mask
@@ -1005,6 +1038,26 @@ func (e *endpoint) Close() {
return
}
+ if e.linger.Enabled && e.linger.Timeout == 0 {
+ s := e.EndpointState()
+ isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv
+ if isResetState {
+ // Close the endpoint without doing full shutdown and
+ // send a RST.
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.closeNoShutdownLocked()
+
+ // Wake up worker to close the endpoint.
+ switch s {
+ case StateSynRecv:
+ e.notifyProtocolGoroutine(notifyClose)
+ default:
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
+ return
+ }
+ }
+
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
@@ -1050,6 +1103,8 @@ func (e *endpoint) closeNoShutdownLocked() {
e.notifyProtocolGoroutine(notifyClose)
} else {
e.transitionToStateCloseLocked()
+ // Notify that the endpoint is closed.
+ e.waiterQueue.Notify(waiter.EventHUp)
}
}
@@ -1104,10 +1159,16 @@ func (e *endpoint) cleanupLocked() {
tcpip.DeleteDanglingEndpoint(e)
}
+// wndFromSpace returns the window that we can advertise based on the available
+// receive buffer space.
+func wndFromSpace(space int) int {
+ return space / (1 << rcvAdvWndScale)
+}
+
// initialReceiveWindow returns the initial receive window to advertise in the
// SYN/SYN-ACK.
func (e *endpoint) initialReceiveWindow() int {
- rcvWnd := e.receiveBufferAvailable()
+ rcvWnd := wndFromSpace(e.receiveBufferAvailable())
if rcvWnd > math.MaxUint16 {
rcvWnd = math.MaxUint16
}
@@ -1184,14 +1245,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
// reject valid data that might already be in flight as the
// acceptable window will shrink.
if rcvWnd > e.rcvBufSize {
- availBefore := e.receiveBufferAvailableLocked()
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvBufSize = rcvWnd
- availAfter := e.receiveBufferAvailableLocked()
- mask := uint32(notifyReceiveWindowChanged)
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- mask |= notifyNonZeroReceiveWindow
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
- e.notifyProtocolGoroutine(mask)
}
// We only update prevCopied when we grow the buffer because in cases
@@ -1209,6 +1268,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
+func (e *endpoint) LastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+ err := e.lastError
+ e.lastError = nil
+ return err
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -1260,18 +1327,22 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
v := views[s.viewToDeliver]
s.viewToDeliver++
+ var delta int
if s.viewToDeliver >= len(views) {
e.rcvList.Remove(s)
+ // We only free up receive buffer space when the segment is released as the
+ // segment is still holding on to the views even though some views have been
+ // read out to the user.
+ delta = s.segMemSize()
s.decRef()
}
e.rcvBufUsed -= len(v)
-
// If the window was small before this read and if the read freed up
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -1284,14 +1355,17 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// indicating the reason why it's not writable.
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
- // The endpoint cannot be written to if it's not connected.
- if !e.EndpointState().connected() {
- switch e.EndpointState() {
- case StateError:
- return 0, e.HardError
- default:
- return 0, tcpip.ErrClosedForSend
- }
+ switch s := e.EndpointState(); {
+ case s == StateError:
+ return 0, e.HardError
+ case !s.connecting() && !s.connected():
+ return 0, tcpip.ErrClosedForSend
+ case s.connecting():
+ // As per RFC793, page 56, a send request arriving when in connecting
+ // state, can be queued to be completed after the state becomes
+ // connected. Return an error code for the caller of endpoint Write to
+ // try again, until the connection handshake is complete.
+ return 0, tcpip.ErrWouldBlock
}
// Check if the connection has already been closed for sends.
@@ -1445,11 +1519,11 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
}
// windowCrossedACKThresholdLocked checks if the receive window to be announced
-// now would be under aMSS or under half receive buffer, whichever smaller. This
-// is useful as a receive side silly window syndrome prevention mechanism. If
-// window grows to reasonable value, we should send ACK to the sender to inform
-// the rx space is now large. We also want ensure a series of small read()'s
-// won't trigger a flood of spurious tiny ACK's.
+// would be under aMSS or under the window derived from half receive buffer,
+// whichever smaller. This is useful as a receive side silly window syndrome
+// prevention mechanism. If window grows to reasonable value, we should send ACK
+// to the sender to inform the rx space is now large. We also want ensure a
+// series of small read()'s won't trigger a flood of spurious tiny ACK's.
//
// For large receive buffers, the threshold is aMSS - once reader reads more
// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of
@@ -1460,17 +1534,18 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
//
// Precondition: e.mu and e.rcvListMu must be held.
func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
- newAvail := e.receiveBufferAvailableLocked()
+ newAvail := wndFromSpace(e.receiveBufferAvailableLocked())
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
oldAvail = 0
}
-
threshold := int(e.amss)
- if threshold > e.rcvBufSize/2 {
- threshold = e.rcvBufSize / 2
+ // rcvBufFraction is the inverse of the fraction of receive buffer size that
+ // is used to decide if the available buffer space is now above it.
+ const rcvBufFraction = 2
+ if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold {
+ threshold = wndThreshold
}
-
switch {
case oldAvail < threshold && newAvail >= threshold:
return true, true
@@ -1589,21 +1664,34 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.UnlockUser()
e.notifyProtocolGoroutine(notifyMSSChanged)
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if attempting to set this option to
+ // anything other than path MTU discovery disabled.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
- var rs ReceiveBufferSizeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err))
+ }
+
+ if v > rs.Max {
+ v = rs.Max
+ }
+
+ if v < math.MaxInt32/SegOverheadFactor {
+ v *= SegOverheadFactor
if v < rs.Min {
v = rs.Min
}
- if v > rs.Max {
- v = rs.Max
- }
+ } else {
+ v = math.MaxInt32
}
- mask := uint32(notifyReceiveWindowChanged)
-
e.LockUser()
e.rcvListMu.Lock()
@@ -1617,14 +1705,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
v = 1 << scale
}
- // Make sure 2*size doesn't overflow.
- if v > math.MaxInt32/2 {
- v = math.MaxInt32 / 2
- }
-
- availBefore := e.receiveBufferAvailableLocked()
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvBufSize = v
- availAfter := e.receiveBufferAvailableLocked()
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvAutoParams.disabled = true
@@ -1632,24 +1715,31 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// syndrome prevetion, when our available space grows above aMSS
// or half receive buffer, whichever smaller.
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- mask |= notifyNonZeroReceiveWindow
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
e.rcvListMu.Unlock()
e.UnlockUser()
- e.notifyProtocolGoroutine(mask)
case tcpip.SendBufferSizeOption:
// Make sure the send buffer size is within the min and max
// allowed.
- var ss SendBufferSizeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ var ss tcpip.TCPSendBufferSizeRangeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err))
+ }
+
+ if v > ss.Max {
+ v = ss.Max
+ }
+
+ if v < math.MaxInt32/SegOverheadFactor {
+ v *= SegOverheadFactor
if v < ss.Min {
v = ss.Min
}
- if v > ss.Max {
- v = ss.Max
- }
+ } else {
+ v = math.MaxInt32
}
e.sndBufMu.Lock()
@@ -1682,7 +1772,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return tcpip.ErrInvalidOptionValue
}
}
- var rs ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
if v < rs.Min/2 {
v = rs.Min / 2
@@ -1696,10 +1786,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
- case tcpip.BindToDeviceOption:
- id := tcpip.NICID(v)
+ case *tcpip.BindToDeviceOption:
+ id := tcpip.NICID(*v)
if id != 0 && !e.stack.HasNIC(id) {
return tcpip.ErrUnknownDevice
}
@@ -1707,40 +1797,40 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.bindToDevice = id
e.UnlockUser()
- case tcpip.KeepaliveIdleOption:
+ case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
- e.keepalive.idle = time.Duration(v)
+ e.keepalive.idle = time.Duration(*v)
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- case tcpip.KeepaliveIntervalOption:
+ case *tcpip.KeepaliveIntervalOption:
e.keepalive.Lock()
- e.keepalive.interval = time.Duration(v)
+ e.keepalive.interval = time.Duration(*v)
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- case tcpip.OutOfBandInlineOption:
+ case *tcpip.OutOfBandInlineOption:
// We don't currently support disabling this option.
- case tcpip.TCPUserTimeoutOption:
+ case *tcpip.TCPUserTimeoutOption:
e.LockUser()
- e.userTimeout = time.Duration(v)
+ e.userTimeout = time.Duration(*v)
e.UnlockUser()
- case tcpip.CongestionControlOption:
+ case *tcpip.CongestionControlOption:
// Query the available cc algorithms in the stack and
// validate that the specified algorithm is actually
// supported in the stack.
- var avail tcpip.AvailableCongestionControlOption
+ var avail tcpip.TCPAvailableCongestionControlOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil {
return err
}
availCC := strings.Split(string(avail), " ")
for _, cc := range availCC {
- if v == tcpip.CongestionControlOption(cc) {
+ if *v == tcpip.CongestionControlOption(cc) {
e.LockUser()
state := e.EndpointState()
- e.cc = v
+ e.cc = *v
switch state {
case StateEstablished:
if e.EndpointState() == state {
@@ -1756,33 +1846,43 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// control algorithm is specified.
return tcpip.ErrNoSuchFile
- case tcpip.TCPLingerTimeoutOption:
+ case *tcpip.TCPLingerTimeoutOption:
e.LockUser()
- if v < 0 {
+
+ switch {
+ case *v < 0:
// Same as effectively disabling TCPLinger timeout.
- v = 0
- }
- var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption
- if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil {
- // We were unable to retrieve a stack config, just use
- // the DefaultTCPLingerTimeout.
- if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) {
- stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout)
+ *v = -1
+ case *v == 0:
+ // Same as the stack default.
+ var stackLingerTimeout tcpip.TCPLingerTimeoutOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &stackLingerTimeout); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %v", ProtocolNumber, &stackLingerTimeout, err))
}
+ *v = stackLingerTimeout
+ case *v > tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout):
+ // Cap it to Stack's default TCP_LINGER2 timeout.
+ *v = tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout)
+ default:
}
- // Cap it to the stack wide TCPLinger timeout.
- if v > stkTCPLingerTimeout {
- v = stkTCPLingerTimeout
- }
- e.tcpLingerTimeout = time.Duration(v)
+
+ e.tcpLingerTimeout = time.Duration(*v)
e.UnlockUser()
- case tcpip.TCPDeferAcceptOption:
+ case *tcpip.TCPDeferAcceptOption:
e.LockUser()
- if time.Duration(v) > MaxRTO {
- v = tcpip.TCPDeferAcceptOption(MaxRTO)
+ if time.Duration(*v) > MaxRTO {
+ *v = tcpip.TCPDeferAcceptOption(MaxRTO)
}
- e.deferAccept = time.Duration(v)
+ e.deferAccept = time.Duration(*v)
+ e.UnlockUser()
+
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+
+ case *tcpip.LingerOption:
+ e.LockUser()
+ e.linger = *v
e.UnlockUser()
default:
@@ -1896,6 +1996,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
v := header.TCPDefaultMSS
return v, nil
+ case tcpip.MTUDiscoverOption:
+ // Always return the path MTU discovery disabled setting since
+ // it's the only one supported.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
@@ -1938,15 +2043,8 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
switch o := opt.(type) {
- case tcpip.ErrorOption:
- e.lastErrorMu.Lock()
- err := e.lastError
- e.lastError = nil
- e.lastErrorMu.Unlock()
- return err
-
case *tcpip.BindToDeviceOption:
e.LockUser()
*o = tcpip.BindToDeviceOption(e.bindToDevice)
@@ -1998,6 +2096,24 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = tcpip.TCPDeferAcceptOption(e.deferAccept)
e.UnlockUser()
+ case *tcpip.OriginalDestinationOption:
+ e.LockUser()
+ ipt := e.stack.IPTables()
+ addr, port, err := ipt.OriginalDst(e.ID, e.NetProto)
+ e.UnlockUser()
+ if err != nil {
+ return err
+ }
+ *o = tcpip.OriginalDestinationOption{
+ Addr: addr,
+ Port: port,
+ }
+
+ case *tcpip.LingerOption:
+ e.LockUser()
+ *o = e.linger
+ e.UnlockUser()
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -2125,12 +2241,66 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
h.Write(portBuf)
portOffset := h.Sum32()
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &twReuse, err))
+ }
+
+ reuse := twReuse == tcpip.TCPTimeWaitReuseGlobal
+ if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly {
+ switch netProto {
+ case header.IPv4ProtocolNumber:
+ reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress)
+ case header.IPv6ProtocolNumber:
+ reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback
+ }
+ }
+
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
- return false, nil
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ if err != tcpip.ErrPortInUse || !reuse {
+ return false, nil
+ }
+ transEPID := e.ID
+ transEPID.LocalPort = p
+ // Check if an endpoint is registered with demuxer in TIME-WAIT and if
+ // we can reuse it. If we can't find a transport endpoint then we just
+ // skip using this port as it's possible that either an endpoint has
+ // bound the port but not registered with demuxer yet (no listen/connect
+ // done yet) or the reservation was freed between the check above and
+ // the FindTransportEndpoint below. But rather than retry the same port
+ // we just skip it and move on.
+ transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r)
+ if transEP == nil {
+ // ReservePort failed but there is no registered endpoint with
+ // demuxer. Which indicates there is at least some endpoint that has
+ // bound the port.
+ return false, nil
+ }
+
+ tcpEP := transEP.(*endpoint)
+ tcpEP.LockUser()
+ // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but
+ // less than 1 second has elapsed since its recentTS was updated then
+ // we cannot reuse the port.
+ if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second {
+ tcpEP.UnlockUser()
+ return false, nil
+ }
+ // Since the endpoint is in TIME-WAIT it should be safe to acquire its
+ // Lock while holding the lock for this endpoint as endpoints in
+ // TIME-WAIT do not acquire locks on other endpoints.
+ tcpEP.workerCleanup = false
+ tcpEP.cleanupLocked()
+ tcpEP.notifyProtocolGoroutine(notifyAbort)
+ tcpEP.UnlockUser()
+ // Now try and Reserve again if it fails then we skip.
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
+ return false, nil
+ }
}
id := e.ID
@@ -2368,7 +2538,9 @@ func (e *endpoint) startAcceptedLoop() {
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode.
-func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+//
+// addr if not-nil will contain the peer address of the returned endpoint.
+func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -2390,6 +2562,9 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
default:
return nil, nil, tcpip.ErrWouldBlock
}
+ if peerAddr != nil {
+ *peerAddr = n.getRemoteAddress()
+ }
return n, n.waiterQueue, nil
}
@@ -2426,47 +2601,45 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
- if err != nil {
- return err
- }
-
- e.boundBindToDevice = e.bindToDevice
- e.boundPortFlags = e.portFlags
- e.isPortReserved = true
- e.effectiveNetProtos = netProtos
- e.ID.LocalPort = port
-
- // Any failures beyond this point must remove the port registration.
- defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) {
- if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{})
- e.isPortReserved = false
- e.effectiveNetProtos = nil
- e.ID.LocalPort = 0
- e.ID.LocalAddress = ""
- e.boundNICID = 0
- e.boundBindToDevice = 0
- e.boundPortFlags = ports.Flags{}
- }
- }(e.boundPortFlags, e.boundBindToDevice)
-
+ var nic tcpip.NICID
// If an address is specified, we must ensure that it's one of our
// local addresses.
if len(addr.Addr) != 0 {
- nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nic == 0 {
return tcpip.ErrBadLocalAddress
}
-
- e.boundNICID = nic
e.ID.LocalAddress = addr.Addr
}
- if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
+ id := e.ID
+ id.LocalPort = p
+ // CheckRegisterTransportEndpoint should only return an error if there is a
+ // listening endpoint bound with the same id and portFlags and bindToDevice
+ // options.
+ //
+ // NOTE: Only listening and connected endpoint register with
+ // demuxer. Further connected endpoints always have a remote
+ // address/port. Hence this will only return an error if there is a matching
+ // listening endpoint.
+ if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil {
+ return false
+ }
+ return true
+ })
+ if err != nil {
return err
}
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = e.portFlags
+ // TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct.
+ e.boundNICID = nic
+ e.isPortReserved = true
+ e.effectiveNetProtos = netProtos
+ e.ID.LocalPort = port
+
// Mark endpoint as bound.
e.setEndpointState(StateBound)
@@ -2494,11 +2667,15 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
+ return e.getRemoteAddress(), nil
+}
+
+func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
return tcpip.FullAddress{
Addr: e.ID.RemoteAddress,
Port: e.ID.RemotePort,
NIC: e.boundNICID,
- }, nil
+ }
}
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
@@ -2531,6 +2708,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.sndBufMu.Unlock()
e.notifyProtocolGoroutine(notifyMTUChanged)
+
+ case stack.ControlNoRoute:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNoRoute
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
+
+ case stack.ControlNetworkUnreachable:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNetworkUnreachable
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
}
}
@@ -2557,13 +2746,8 @@ func (e *endpoint) updateSndBufferUsage(v int) {
func (e *endpoint) readyToRead(s *segment) {
e.rcvListMu.Lock()
if s != nil {
+ e.rcvBufUsed += s.payloadSize()
s.incRef()
- e.rcvBufUsed += s.data.Size()
- // Increase counter if the receive window falls down below MSS
- // or half receive buffer size, whichever smaller.
- if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
- e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
- }
e.rcvList.PushBack(s)
} else {
e.rcvClosed = true
@@ -2578,15 +2762,17 @@ func (e *endpoint) readyToRead(s *segment) {
func (e *endpoint) receiveBufferAvailableLocked() int {
// We may use more bytes than the buffer size when the receive buffer
// shrinks.
- if e.rcvBufUsed >= e.rcvBufSize {
+ memUsed := e.receiveMemUsed()
+ if memUsed >= e.rcvBufSize {
return 0
}
- return e.rcvBufSize - e.rcvBufUsed
+ return e.rcvBufSize - memUsed
}
// receiveBufferAvailable calculates how many bytes are still available in the
-// receive buffer.
+// receive buffer based on the actual memory used by all segments held in
+// receive buffer/pending and segment queue.
func (e *endpoint) receiveBufferAvailable() int {
e.rcvListMu.Lock()
available := e.receiveBufferAvailableLocked()
@@ -2594,16 +2780,37 @@ func (e *endpoint) receiveBufferAvailable() int {
return available
}
+// receiveBufferUsed returns the amount of in-use receive buffer.
+func (e *endpoint) receiveBufferUsed() int {
+ e.rcvListMu.Lock()
+ used := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+ return used
+}
+
+// receiveBufferSize returns the current size of the receive buffer.
func (e *endpoint) receiveBufferSize() int {
e.rcvListMu.Lock()
size := e.rcvBufSize
e.rcvListMu.Unlock()
-
return size
}
+// receiveMemUsed returns the total memory in use by segments held by this
+// endpoint.
+func (e *endpoint) receiveMemUsed() int {
+ return int(atomic.LoadInt32(&e.rcvMemUsed))
+}
+
+// updateReceiveMemUsed adds the provided delta to e.rcvMemUsed.
+func (e *endpoint) updateReceiveMemUsed(delta int) {
+ atomic.AddInt32(&e.rcvMemUsed, int32(delta))
+}
+
+// maxReceiveBufferSize returns the stack wide maximum receive buffer size for
+// an endpoint.
func (e *endpoint) maxReceiveBufferSize() int {
- var rs ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
// As a fallback return the hardcoded max buffer size.
return MaxBufferSize
@@ -2650,15 +2857,14 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
// timestamp returns the timestamp value to be used in the TSVal field of the
// timestamp option for outgoing TCP segments for a given endpoint.
func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(e.tsOffset)
+ return tcpTimeStamp(time.Now(), e.tsOffset)
}
// tcpTimeStamp returns a timestamp offset by the provided offset. This is
// not inlined above as it's used when SYN cookies are in use and endpoint
// is not created at the time when the SYN cookie is sent.
-func tcpTimeStamp(offset uint32) uint32 {
- now := time.Now()
- return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset
+func tcpTimeStamp(curTime time.Time, offset uint32) uint32 {
+ return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset
}
// timeStampOffset returns a randomized timestamp offset to be used when sending
@@ -2684,7 +2890,7 @@ func timeStampOffset() uint32 {
// if the SYN options indicate that the SACK option was negotiated and the TCP
// stack is configured to enable TCP SACK option.
func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
- var v SACKEnabled
+ var v tcpip.TCPSACKEnabled
if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
return
@@ -2753,7 +2959,6 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
RcvAcc: e.rcv.rcvAcc,
RcvWndScale: e.rcv.rcvWndScale,
PendingBufUsed: e.rcv.pendingBufUsed,
- PendingBufSize: e.rcv.pendingBufSize,
}
// Copy sender state.
@@ -2801,6 +3006,14 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
WEst: cubic.wEst,
}
}
+
+ rc := e.snd.rc
+ s.Sender.RACKState = stack.TCPRACKState{
+ XmitTime: rc.xmitTime,
+ EndSequence: rc.endSequence,
+ FACK: rc.fack,
+ RTT: rc.rtt,
+ }
return s
}
@@ -2869,8 +3082,3 @@ func (e *endpoint) Wait() {
<-notifyCh
}
}
-
-func mssForRoute(r *stack.Route) uint16 {
- // TODO(b/143359391): Respect TCP Min and Max size.
- return uint16(r.MTU() - header.TCPMinimumSize)
-}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index abf1ac5c9..b25431467 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -44,7 +44,7 @@ func (e *endpoint) drainSegmentLocked() {
// beforeSave is invoked by stateify.
func (e *endpoint) beforeSave() {
// Stop incoming packets.
- e.segmentQueue.setLimit(0)
+ e.segmentQueue.freeze()
e.mu.Lock()
defer e.mu.Unlock()
@@ -178,18 +178,18 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
- e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.segmentQueue.thaw()
epState := e.origEndpointState
switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
- var ss SendBufferSizeOption
+ var ss tcpip.TCPSendBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
}
}
- var rs ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
@@ -309,6 +309,16 @@ func (e *endpoint) loadLastError(s string) {
e.lastError = tcpip.StringToError(s)
}
+// saveRecentTSTime is invoked by stateify.
+func (e *endpoint) saveRecentTSTime() unixTime {
+ return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()}
+}
+
+// loadRecentTSTime is invoked by stateify.
+func (e *endpoint) loadRecentTSTime(unix unixTime) {
+ e.recentTSTime = time.Unix(unix.second, unix.nano)
+}
+
// saveHardError is invoked by stateify.
func (e *EndpointInfo) saveHardError() string {
if e.HardError == nil {
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index b34e47bbd..5bce73605 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -12,16 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package tcp contains the implementation of the TCP transport protocol. To use
-// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing tcp.NewProtocol() as one of the
-// transport protocols when calling stack.New(). Then endpoints can be created
-// by passing tcp.ProtocolNumber as the transport protocol number when calling
-// Stack.NewEndpoint().
+// Package tcp contains the implementation of the TCP transport protocol.
package tcp
import (
- "fmt"
"runtime"
"strings"
"time"
@@ -30,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
@@ -62,6 +57,10 @@ const (
// FIN_WAIT_2 state before being marked closed.
DefaultTCPLingerTimeout = 60 * time.Second
+ // MaxTCPLingerTimeout is the maximum amount of time that sockets
+ // linger in FIN_WAIT_2 state before being marked closed.
+ MaxTCPLingerTimeout = 120 * time.Second
+
// DefaultTCPTimeWaitTimeout is the amount of time that sockets linger
// in TIME_WAIT state before being marked closed.
DefaultTCPTimeWaitTimeout = 60 * time.Second
@@ -76,31 +75,6 @@ const (
ccCubic = "cubic"
)
-// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to
-// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018.
-type SACKEnabled bool
-
-// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to
-// enable/disable Nagle's algorithm in TCP.
-type DelayEnabled bool
-
-// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption
-// to get/set the default, min and max TCP send buffer sizes.
-type SendBufferSizeOption struct {
- Min int
- Default int
- Max int
-}
-
-// ReceiveBufferSizeOption is used by
-// stack.(Stack*).TransportProtocolOption to get/set the default, min and max
-// TCP receive buffer sizes.
-type ReceiveBufferSizeOption struct {
- Min int
- Default int
- Max int
-}
-
// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
// value is protected by a mutex so that we can increment only when it's
// guaranteed not to go above a threshold.
@@ -159,16 +133,20 @@ func (s *synRcvdCounter) Threshold() uint64 {
}
type protocol struct {
+ stack *stack.Stack
+
mu sync.RWMutex
sackEnabled bool
+ recovery tcpip.TCPRecovery
delayEnabled bool
- sendBufferSize SendBufferSizeOption
- recvBufferSize ReceiveBufferSizeOption
+ sendBufferSize tcpip.TCPSendBufferSizeRangeOption
+ recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption
congestionControl string
availableCongestionControl []string
moderateReceiveBuffer bool
- tcpLingerTimeout time.Duration
- tcpTimeWaitTimeout time.Duration
+ lingerTimeout time.Duration
+ timeWaitTimeout time.Duration
+ timeWaitReuse tcpip.TCPTimeWaitReuseOption
minRTO time.Duration
maxRTO time.Duration
maxRetries uint32
@@ -183,14 +161,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new tcp endpoint.
-func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newEndpoint(stack, netProto, waiterQueue), nil
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue)
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue)
}
// MinimumPacketSize returns the minimum valid tcp packet size.
@@ -220,21 +198,20 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
s := newSegment(r, id, pkt)
defer s.decRef()
if !s.parse() || !s.csumValid {
- return false
+ return stack.UnknownDestinationPacketMalformed
}
- // There's nothing to do if this is already a reset packet.
- if s.flagIsSet(header.TCPFlagRst) {
- return true
+ if !s.flagIsSet(header.TCPFlagRst) {
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
}
- replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
- return true
+ return stack.UnknownDestinationPacketHandled
}
// replyWithReset replies to the given segment with a reset segment.
@@ -272,43 +249,49 @@ func replyWithReset(s *segment, tos, ttl uint8) {
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case SACKEnabled:
+ case *tcpip.TCPSACKEnabled:
+ p.mu.Lock()
+ p.sackEnabled = bool(*v)
+ p.mu.Unlock()
+ return nil
+
+ case *tcpip.TCPRecovery:
p.mu.Lock()
- p.sackEnabled = bool(v)
+ p.recovery = *v
p.mu.Unlock()
return nil
- case DelayEnabled:
+ case *tcpip.TCPDelayEnabled:
p.mu.Lock()
- p.delayEnabled = bool(v)
+ p.delayEnabled = bool(*v)
p.mu.Unlock()
return nil
- case SendBufferSizeOption:
+ case *tcpip.TCPSendBufferSizeRangeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
- p.sendBufferSize = v
+ p.sendBufferSize = *v
p.mu.Unlock()
return nil
- case ReceiveBufferSizeOption:
+ case *tcpip.TCPReceiveBufferSizeRangeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
- p.recvBufferSize = v
+ p.recvBufferSize = *v
p.mu.Unlock()
return nil
- case tcpip.CongestionControlOption:
+ case *tcpip.CongestionControlOption:
for _, c := range p.availableCongestionControl {
- if string(v) == c {
+ if string(*v) == c {
p.mu.Lock()
- p.congestionControl = string(v)
+ p.congestionControl = string(*v)
p.mu.Unlock()
return nil
}
@@ -317,66 +300,79 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
// is specified.
return tcpip.ErrNoSuchFile
- case tcpip.ModerateReceiveBufferOption:
+ case *tcpip.TCPModerateReceiveBufferOption:
p.mu.Lock()
- p.moderateReceiveBuffer = bool(v)
+ p.moderateReceiveBuffer = bool(*v)
p.mu.Unlock()
return nil
- case tcpip.TCPLingerTimeoutOption:
- if v < 0 {
- v = 0
- }
+ case *tcpip.TCPLingerTimeoutOption:
p.mu.Lock()
- p.tcpLingerTimeout = time.Duration(v)
+ if *v < 0 {
+ p.lingerTimeout = 0
+ } else {
+ p.lingerTimeout = time.Duration(*v)
+ }
p.mu.Unlock()
return nil
- case tcpip.TCPTimeWaitTimeoutOption:
- if v < 0 {
- v = 0
- }
+ case *tcpip.TCPTimeWaitTimeoutOption:
p.mu.Lock()
- p.tcpTimeWaitTimeout = time.Duration(v)
+ if *v < 0 {
+ p.timeWaitTimeout = 0
+ } else {
+ p.timeWaitTimeout = time.Duration(*v)
+ }
p.mu.Unlock()
return nil
- case tcpip.TCPMinRTOOption:
- if v < 0 {
- v = tcpip.TCPMinRTOOption(MinRTO)
+ case *tcpip.TCPTimeWaitReuseOption:
+ if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly {
+ return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
- p.minRTO = time.Duration(v)
+ p.timeWaitReuse = *v
p.mu.Unlock()
return nil
- case tcpip.TCPMaxRTOOption:
- if v < 0 {
- v = tcpip.TCPMaxRTOOption(MaxRTO)
+ case *tcpip.TCPMinRTOOption:
+ p.mu.Lock()
+ if *v < 0 {
+ p.minRTO = MinRTO
+ } else {
+ p.minRTO = time.Duration(*v)
}
+ p.mu.Unlock()
+ return nil
+
+ case *tcpip.TCPMaxRTOOption:
p.mu.Lock()
- p.maxRTO = time.Duration(v)
+ if *v < 0 {
+ p.maxRTO = MaxRTO
+ } else {
+ p.maxRTO = time.Duration(*v)
+ }
p.mu.Unlock()
return nil
- case tcpip.TCPMaxRetriesOption:
+ case *tcpip.TCPMaxRetriesOption:
p.mu.Lock()
- p.maxRetries = uint32(v)
+ p.maxRetries = uint32(*v)
p.mu.Unlock()
return nil
- case tcpip.TCPSynRcvdCountThresholdOption:
+ case *tcpip.TCPSynRcvdCountThresholdOption:
p.mu.Lock()
- p.synRcvdCount.SetThreshold(uint64(v))
+ p.synRcvdCount.SetThreshold(uint64(*v))
p.mu.Unlock()
return nil
- case tcpip.TCPSynRetriesOption:
- if v < 1 || v > 255 {
+ case *tcpip.TCPSynRetriesOption:
+ if *v < 1 || *v > 255 {
return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
- p.synRetries = uint8(v)
+ p.synRetries = uint8(*v)
p.mu.Unlock()
return nil
@@ -386,27 +382,33 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
// Option implements stack.TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case *SACKEnabled:
+ case *tcpip.TCPSACKEnabled:
+ p.mu.RLock()
+ *v = tcpip.TCPSACKEnabled(p.sackEnabled)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPRecovery:
p.mu.RLock()
- *v = SACKEnabled(p.sackEnabled)
+ *v = tcpip.TCPRecovery(p.recovery)
p.mu.RUnlock()
return nil
- case *DelayEnabled:
+ case *tcpip.TCPDelayEnabled:
p.mu.RLock()
- *v = DelayEnabled(p.delayEnabled)
+ *v = tcpip.TCPDelayEnabled(p.delayEnabled)
p.mu.RUnlock()
return nil
- case *SendBufferSizeOption:
+ case *tcpip.TCPSendBufferSizeRangeOption:
p.mu.RLock()
*v = p.sendBufferSize
p.mu.RUnlock()
return nil
- case *ReceiveBufferSizeOption:
+ case *tcpip.TCPReceiveBufferSizeRangeOption:
p.mu.RLock()
*v = p.recvBufferSize
p.mu.RUnlock()
@@ -418,27 +420,33 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.RUnlock()
return nil
- case *tcpip.AvailableCongestionControlOption:
+ case *tcpip.TCPAvailableCongestionControlOption:
p.mu.RLock()
- *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ *v = tcpip.TCPAvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
p.mu.RUnlock()
return nil
- case *tcpip.ModerateReceiveBufferOption:
+ case *tcpip.TCPModerateReceiveBufferOption:
p.mu.RLock()
- *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer)
+ *v = tcpip.TCPModerateReceiveBufferOption(p.moderateReceiveBuffer)
p.mu.RUnlock()
return nil
case *tcpip.TCPLingerTimeoutOption:
p.mu.RLock()
- *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
+ *v = tcpip.TCPLingerTimeoutOption(p.lingerTimeout)
p.mu.RUnlock()
return nil
case *tcpip.TCPTimeWaitTimeoutOption:
p.mu.RLock()
- *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
+ *v = tcpip.TCPTimeWaitTimeoutOption(p.timeWaitTimeout)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPTimeWaitReuseOption:
+ p.mu.RLock()
+ *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse)
p.mu.RUnlock()
return nil
@@ -495,46 +503,34 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter {
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
- hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize)
- if !ok {
- return false
- }
-
- // If the header has options, pull those up as well.
- if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
- hdr, ok = pkt.Data.PullUp(offset)
- if !ok {
- panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset))
- }
- }
-
- pkt.TransportHeader = hdr
- pkt.Data.TrimFront(len(hdr))
- return true
+ return parse.TCP(pkt)
}
// NewProtocol returns a TCP transport protocol.
-func NewProtocol() stack.TransportProtocol {
+func NewProtocol(s *stack.Stack) stack.TransportProtocol {
p := protocol{
- sendBufferSize: SendBufferSizeOption{
+ stack: s,
+ sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{
Min: MinBufferSize,
Default: DefaultSendBufferSize,
Max: MaxBufferSize,
},
- recvBufferSize: ReceiveBufferSizeOption{
+ recvBufferSize: tcpip.TCPReceiveBufferSizeRangeOption{
Min: MinBufferSize,
Default: DefaultReceiveBufferSize,
Max: MaxBufferSize,
},
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
- tcpLingerTimeout: DefaultTCPLingerTimeout,
- tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ lingerTimeout: DefaultTCPLingerTimeout,
+ timeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly,
synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
synRetries: DefaultSynRetries,
minRTO: MinRTO,
maxRTO: MaxRTO,
maxRetries: MaxRetries,
+ recovery: tcpip.TCPRACKLossDetection,
}
p.dispatcher.init(runtime.GOMAXPROCS(0))
return &p
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
new file mode 100644
index 000000000..439932595
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+// RACK is a loss detection algorithm used in TCP to detect packet loss and
+// reordering using transmission timestamp of the packets instead of packet or
+// sequence counts. To use RACK, SACK should be enabled on the connection.
+
+// rackControl stores the rack related fields.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-6.1
+//
+// +stateify savable
+type rackControl struct {
+ // xmitTime is the latest transmission timestamp of rackControl.seg.
+ xmitTime time.Time `state:".(unixTime)"`
+
+ // endSequence is the ending TCP sequence number of rackControl.seg.
+ endSequence seqnum.Value
+
+ // fack is the highest selectively or cumulatively acknowledged
+ // sequence.
+ fack seqnum.Value
+
+ // minRTT is the estimated minimum RTT of the connection.
+ minRTT time.Duration
+
+ // rtt is the RTT of the most recently delivered packet on the
+ // connection (either cumulatively acknowledged or selectively
+ // acknowledged) that was not marked invalid as a possible spurious
+ // retransmission.
+ rtt time.Duration
+}
+
+// Update will update the RACK related fields when an ACK has been received.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+func (rc *rackControl) Update(seg *segment, ackSeg *segment, offset uint32) {
+ rtt := time.Now().Sub(seg.xmitTime)
+
+ // If the ACK is for a retransmitted packet, do not update if it is a
+ // spurious inference which is determined by below checks:
+ // 1. When Timestamping option is available, if the TSVal is less than the
+ // transmit time of the most recent retransmitted packet.
+ // 2. When RTT calculated for the packet is less than the smoothed RTT
+ // for the connection.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+ // step 2
+ if seg.xmitCount > 1 {
+ if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 {
+ if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) {
+ return
+ }
+ }
+ if rtt < rc.minRTT {
+ return
+ }
+ }
+
+ rc.rtt = rtt
+
+ // The sender can either track a simple global minimum of all RTT
+ // measurements from the connection, or a windowed min-filtered value
+ // of recent RTT measurements. This implementation keeps track of the
+ // simple global minimum of all RTTs for the connection.
+ if rtt < rc.minRTT || rc.minRTT == 0 {
+ rc.minRTT = rtt
+ }
+
+ // Update rc.xmitTime and rc.endSequence to the transmit time and
+ // ending sequence number of the packet which has been acknowledged
+ // most recently.
+ endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) {
+ rc.xmitTime = seg.xmitTime
+ rc.endSequence = endSeq
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go
new file mode 100644
index 000000000..c9dc7e773
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rack_state.go
@@ -0,0 +1,29 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveXmitTime is invoked by stateify.
+func (rc *rackControl) saveXmitTime() unixTime {
+ return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()}
+}
+
+// loadXmitTime is invoked by stateify.
+func (rc *rackControl) loadXmitTime(unix unixTime) {
+ rc.xmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index dd89a292a..48bf196d8 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -47,22 +47,24 @@ type receiver struct {
closed bool
+ // pendingRcvdSegments is bounded by the receive buffer size of the
+ // endpoint.
pendingRcvdSegments segmentHeap
- pendingBufUsed seqnum.Size
- pendingBufSize seqnum.Size
+ // pendingBufUsed tracks the total number of bytes (including segment
+ // overhead) currently queued in pendingRcvdSegments.
+ pendingBufUsed int
// Time when the last ack was received.
lastRcvdAckTime time.Time `state:".(unixTime)"`
}
-func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver {
+func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
return &receiver{
ep: ep,
rcvNxt: irs + 1,
rcvAcc: irs.Add(rcvWnd + 1),
rcvWnd: rcvWnd,
rcvWndScale: rcvWndScale,
- pendingBufSize: pendingBufSize,
lastRcvdAckTime: time.Now(),
}
}
@@ -85,15 +87,30 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
// getSendParams returns the parameters needed by the sender when building
// segments to send.
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
- // Calculate the window size based on the available buffer space.
- receiveBufferAvailable := r.ep.receiveBufferAvailable()
- acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable))
- if r.rcvAcc.LessThan(acc) {
- r.rcvAcc = acc
+ avail := wndFromSpace(r.ep.receiveBufferAvailable())
+ if avail == 0 {
+ // We have no space available to accept any data, move to zero window
+ // state.
+ r.rcvWnd = 0
+ return r.rcvNxt, 0
+ }
+
+ acc := r.rcvNxt.Add(seqnum.Size(avail))
+ newWnd := r.rcvNxt.Size(acc)
+ curWnd := r.rcvNxt.Size(r.rcvAcc)
+
+ // Update rcvAcc only if new window is > previously advertised window. We
+ // should never shrink the acceptable sequence space once it has been
+ // advertised the peer. If we shrink the acceptable sequence space then we
+ // would end up dropping bytes that might already be in flight.
+ if newWnd > curWnd {
+ r.rcvAcc = r.rcvNxt.Add(newWnd)
+ } else {
+ newWnd = curWnd
}
// Stash away the non-scaled receive window as we use it for measuring
// receiver's estimated RTT.
- r.rcvWnd = r.rcvNxt.Size(r.rcvAcc)
+ r.rcvWnd = newWnd
return r.rcvNxt, r.rcvWnd >> r.rcvWndScale
}
@@ -195,7 +212,9 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
}
for i := first; i < len(r.pendingRcvdSegments); i++ {
+ r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize()
r.pendingRcvdSegments[i].decRef()
+
// Note that slice truncation does not allow garbage collection of
// truncated items, thus truncated items must be set to nil to avoid
// memory leaks.
@@ -268,14 +287,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// If we are in one of the shutdown states then we need to do
// additional checks before we try and process the segment.
switch state {
- case StateCloseWait:
- // If the ACK acks something not yet sent then we send an ACK.
- if r.ep.snd.sndNxt.LessThan(s.ackNumber) {
- r.ep.snd.sendAck()
- return true, nil
- }
- fallthrough
- case StateClosing, StateLastAck:
+ case StateCloseWait, StateClosing, StateLastAck:
if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
// Just drop the segment as we have
// already received a FIN and this
@@ -284,9 +296,31 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
return true, nil
}
fallthrough
- case StateFinWait1:
- fallthrough
- case StateFinWait2:
+ case StateFinWait1, StateFinWait2:
+ // If the ACK acks something not yet sent then we send an ACK.
+ //
+ // RFC793, page 37: If the connection is in a synchronized state,
+ // (ESTABLISHED, FIN-WAIT-1, FIN-WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK,
+ // TIME-WAIT), any unacceptable segment (out of window sequence number
+ // or unacceptable acknowledgment number) must elicit only an empty
+ // acknowledgment segment containing the current send-sequence number
+ // and an acknowledgment indicating the next sequence number expected
+ // to be received, and the connection remains in the same state.
+ //
+ // Just as on Linux, we do not apply this behavior when state is
+ // ESTABLISHED.
+ // Linux receive processing for all states except ESTABLISHED and
+ // TIME_WAIT is here where if the ACK check fails, we attempt to
+ // reply back with an ACK with correct seq/ack numbers.
+ // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L6186
+ // The ESTABLISHED state processing is here where if the ACK check
+ // fails, we ignore the packet:
+ // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591
+ if r.ep.snd.sndNxt.LessThan(s.ackNumber) {
+ r.ep.snd.sendAck()
+ return true, nil
+ }
+
// If we are closed for reads (either due to an
// incoming FIN or the user calling shutdown(..,
// SHUT_RD) then any data past the rcvNxt should
@@ -369,10 +403,16 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
// Defer segment processing if it can't be consumed now.
if !r.consumeSegment(s, segSeq, segLen) {
if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
- // We only store the segment if it's within our buffer
- // size limit.
- if r.pendingBufUsed < r.pendingBufSize {
- r.pendingBufUsed += s.logicalLen()
+ // We only store the segment if it's within our buffer size limit.
+ //
+ // Only use 75% of the receive buffer queue for out-of-order
+ // segments. This ensures that we always leave some space for the inorder
+ // segments to arrive allowing pending segments to be processed and
+ // delivered to the user.
+ if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 {
+ r.ep.rcvListMu.Lock()
+ r.pendingBufUsed += s.segMemSize()
+ r.ep.rcvListMu.Unlock()
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
@@ -406,7 +446,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
}
heap.Pop(&r.pendingRcvdSegments)
- r.pendingBufUsed -= s.logicalLen()
+ r.ep.rcvListMu.Lock()
+ r.pendingBufUsed -= s.segMemSize()
+ r.ep.rcvListMu.Unlock()
s.decRef()
}
return false, nil
@@ -421,6 +463,13 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
// Just silently drop any RST packets in TIME_WAIT. We do not support
// TIME_WAIT assasination as a result we confirm w/ fix 1 as described
// in https://tools.ietf.org/html/rfc1337#section-3.
+ //
+ // This behavior overrides RFC793 page 70 where we transition to CLOSED
+ // on receiving RST, which is also default Linux behavior.
+ // On Linux the RST can be ignored by setting sysctl net.ipv4.tcp_rfc1337.
+ //
+ // As we do not yet support PAWS, we are being conservative in ignoring
+ // RSTs by default.
if s.flagIsSet(header.TCPFlagRst) {
return false, false
}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 0280892a8..13acaf753 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "fmt"
"sync/atomic"
"time"
@@ -24,6 +25,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// queueFlags are used to indicate which queue of an endpoint a particular segment
+// belongs to. This is used to track memory accounting correctly.
+type queueFlags uint8
+
+const (
+ recvQ queueFlags = 1 << iota
+ sendQ
+)
+
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
// segment is mostly immutable, the only field allowed to change is viewToDeliver.
@@ -32,6 +42,8 @@ import (
type segment struct {
segmentEntry
refCnt int32
+ ep *endpoint
+ qFlags queueFlags
id stack.TransportEndpointID `state:"manual"`
route stack.Route `state:"manual"`
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
@@ -68,7 +80,7 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB
route: r.Clone(),
}
s.data = pkt.Data.Clone(s.views[:])
- s.hdr = header.TCP(pkt.TransportHeader)
+ s.hdr = header.TCP(pkt.TransportHeader().View())
s.rcvdTime = time.Now()
return s
}
@@ -100,6 +112,8 @@ func (s *segment) clone() *segment {
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
xmitCount: s.xmitCount,
+ ep: s.ep,
+ qFlags: s.qFlags,
}
t.data = s.data.Clone(t.views[:])
return t
@@ -115,8 +129,34 @@ func (s *segment) flagsAreSet(flags uint8) bool {
return s.flags&flags == flags
}
+// setOwner sets the owning endpoint for this segment. Its required
+// to be called to ensure memory accounting for receive/send buffer
+// queues is done properly.
+func (s *segment) setOwner(ep *endpoint, qFlags queueFlags) {
+ switch qFlags {
+ case recvQ:
+ ep.updateReceiveMemUsed(s.segMemSize())
+ case sendQ:
+ // no memory account for sendQ yet.
+ default:
+ panic(fmt.Sprintf("unexpected queue flag %b", qFlags))
+ }
+ s.ep = ep
+ s.qFlags = qFlags
+}
+
func (s *segment) decRef() {
if atomic.AddInt32(&s.refCnt, -1) == 0 {
+ if s.ep != nil {
+ switch s.qFlags {
+ case recvQ:
+ s.ep.updateReceiveMemUsed(-s.segMemSize())
+ case sendQ:
+ // no memory accounting for sendQ yet.
+ default:
+ panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags))
+ }
+ }
s.route.Release()
}
}
@@ -138,6 +178,17 @@ func (s *segment) logicalLen() seqnum.Size {
return l
}
+// payloadSize is the size of s.data.
+func (s *segment) payloadSize() int {
+ return s.data.Size()
+}
+
+// segMemSize is the amount of memory used to hold the segment data and
+// the associated metadata.
+func (s *segment) segMemSize() int {
+ return segSize + s.data.Size()
+}
+
// parse populates the sequence & ack numbers, flags, and window fields of the
// segment from the TCP header stored in the data. It then updates the view to
// skip the header.
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index 48a257137..54545a1b1 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -22,16 +22,16 @@ import (
//
// +stateify savable
type segmentQueue struct {
- mu sync.Mutex `state:"nosave"`
- list segmentList `state:"wait"`
- limit int
- used int
+ mu sync.Mutex `state:"nosave"`
+ list segmentList `state:"wait"`
+ ep *endpoint
+ frozen bool
}
// emptyLocked determines if the queue is empty.
// Preconditions: q.mu must be held.
func (q *segmentQueue) emptyLocked() bool {
- return q.used == 0
+ return q.list.Empty()
}
// empty determines if the queue is empty.
@@ -43,14 +43,6 @@ func (q *segmentQueue) empty() bool {
return r
}
-// setLimit updates the limit. No segments are immediately dropped in case the
-// queue becomes full due to the new limit.
-func (q *segmentQueue) setLimit(limit int) {
- q.mu.Lock()
- q.limit = limit
- q.mu.Unlock()
-}
-
// enqueue adds the given segment to the queue.
//
// Returns true when the segment is successfully added to the queue, in which
@@ -58,15 +50,23 @@ func (q *segmentQueue) setLimit(limit int) {
// false if the queue is full, in which case ownership is retained by the
// caller.
func (q *segmentQueue) enqueue(s *segment) bool {
+ // q.ep.receiveBufferParams() must be called without holding q.mu to
+ // avoid lock order inversion.
+ bufSz := q.ep.receiveBufferSize()
+ used := q.ep.receiveMemUsed()
q.mu.Lock()
- r := q.used < q.limit
- if r {
+ // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue
+ // is currently full).
+ allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen
+
+ if allow {
q.list.PushBack(s)
- q.used++
+ // Set the owner now that the endpoint owns the segment.
+ s.setOwner(q.ep, recvQ)
}
q.mu.Unlock()
- return r
+ return allow
}
// dequeue removes and returns the next segment from queue, if one exists.
@@ -77,9 +77,25 @@ func (q *segmentQueue) dequeue() *segment {
s := q.list.Front()
if s != nil {
q.list.Remove(s)
- q.used--
}
q.mu.Unlock()
return s
}
+
+// freeze prevents any more segments from being added to the queue. i.e all
+// future segmentQueue.enqueue will return false and not add the segment to the
+// queue till the queue is unfroze with a corresponding segmentQueue.thaw call.
+func (q *segmentQueue) freeze() {
+ q.mu.Lock()
+ q.frozen = true
+ q.mu.Unlock()
+}
+
+// thaw unfreezes a previously frozen queue using segmentQueue.freeze() and
+// allows new segments to be queued again.
+func (q *segmentQueue) thaw() {
+ q.mu.Lock()
+ q.frozen = false
+ q.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go
new file mode 100644
index 000000000..0ab7b8f56
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_unsafe.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "unsafe"
+)
+
+const (
+ segSize = int(unsafe.Sizeof(segment{}))
+)
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 5862c32f2..4c9a86cda 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -191,6 +191,10 @@ type sender struct {
// cc is the congestion control algorithm in use for this sender.
cc congestionControl
+
+ // rc has the fields needed for implementing RACK loss detection
+ // algorithm.
+ rc rackControl
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
@@ -1272,21 +1276,21 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
-func (s *sender) handleRcvdSegment(seg *segment) {
+func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Check if we can extract an RTT measurement from this ack.
- if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+ if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) {
s.updateRTO(time.Now().Sub(s.rttMeasureTime))
s.rttMeasureSeqNum = s.sndNxt
}
// Update Timestamp if required. See RFC7323, section-4.3.
- if s.ep.sendTSOk && seg.parsedOptions.TS {
- s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)
+ if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS {
+ s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber)
}
// Insert SACKBlock information into our scoreboard.
if s.ep.sackPermitted {
- for _, sb := range seg.parsedOptions.SACKBlocks {
+ for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
// Only insert the SACK block if the following holds
// true:
// * SACK block acks data after the ack number in the
@@ -1299,27 +1303,27 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// NOTE: This check specifically excludes DSACK blocks
// which have start/end before sndUna and are used to
// indicate spurious retransmissions.
- if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
+ if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
s.ep.scoreboard.Insert(sb)
- seg.hasNewSACKInfo = true
+ rcvdSeg.hasNewSACKInfo = true
}
}
s.SetPipe()
}
// Count the duplicates and do the fast retransmit if needed.
- rtx := s.checkDuplicateAck(seg)
+ rtx := s.checkDuplicateAck(rcvdSeg)
// Stash away the current window size.
- s.sndWnd = seg.window
+ s.sndWnd = rcvdSeg.window
- ack := seg.ackNumber
+ ack := rcvdSeg.ackNumber
// Disable zero window probing if remote advertizes a non-zero receive
// window. This can be with an ACK to the zero window probe (where the
// acknumber refers to the already acknowledged byte) OR to any previously
// unacknowledged segment.
- if s.zeroWindowProbing && seg.window > 0 &&
+ if s.zeroWindowProbing && rcvdSeg.window > 0 &&
(ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) {
s.disableZeroWindowProbing()
}
@@ -1344,10 +1348,10 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// averaged RTT measurement only if the segment acknowledges
// some new data, i.e., only if it advances the left edge of
// the send window.
- if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
+ if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
// TSVal/Ecr values sent by Netstack are at a millisecond
// granularity.
- elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
+ elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond
s.updateRTO(elapsed)
}
@@ -1380,6 +1384,11 @@ func (s *sender) handleRcvdSegment(seg *segment) {
s.writeNext = seg.Next()
}
+ // Update the RACK fields if SACK is enabled.
+ if s.ep.sackPermitted {
+ s.rc.Update(seg, rcvdSeg, s.ep.tsOffset)
+ }
+
s.writeList.Remove(seg)
// if SACK is enabled then Only reduce outstanding if
@@ -1435,7 +1444,7 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// that the window opened up, or the congestion window was inflated due
// to a duplicate ack during fast recovery. This will also re-enable
// the retransmit timer if needed.
- if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo {
+ if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo {
s.sendData()
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
new file mode 100644
index 000000000..e03f101e8
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+)
+
+// TestRACKUpdate tests the RACK related fields are updated when an ACK is
+// received on a SACK enabled connection.
+func TestRACKUpdate(t *testing.T) {
+ const maxPayload = 10
+ const tsOptionSize = 12
+ const maxTCPOptionSize = 40
+
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ defer c.Cleanup()
+
+ var xmitTime time.Time
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint Sender.RACKState is what we expect.
+ if state.Sender.RACKState.XmitTime.Before(xmitTime) {
+ t.Fatalf("RACK transmit time failed to update when an ACK is received")
+ }
+
+ gotSeq := state.Sender.RACKState.EndSequence
+ wantSeq := state.Sender.SndNxt
+ if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) {
+ t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq)
+ }
+
+ if state.Sender.RACKState.RTT == 0 {
+ t.Fatalf("RACK RTT failed to update when an ACK is received")
+ }
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ data := buffer.NewView(maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write the data.
+ xmitTime = time.Now()
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ bytesRead := 0
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ c.SendAck(790, bytesRead)
+ time.Sleep(200 * time.Millisecond)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 99521f0c1..ef7f5719f 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -46,8 +46,9 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
t.Helper()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err)
+ opt := tcpip.TCPSACKEnabled(enable)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -162,8 +163,9 @@ func TestSackPermittedAccept(t *testing.T) {
// Set the SynRcvd threshold to
// zero to force a syn cookie
// based accept to happen.
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
setStackSACKPermitted(t, c, sackEnabled)
@@ -236,8 +238,9 @@ func TestSackDisabledAccept(t *testing.T) {
// Set the SynRcvd threshold to
// zero to force a syn cookie
// based accept to happen.
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 169adb16b..5b504d0d1 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -74,8 +75,8 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
- if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted)
+ if err := ep.LastError(); err != tcpip.ErrAborted {
+ t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted)
}
// Call Connect again to retreive the handshake failure status
@@ -146,6 +147,24 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
}
}
+func TestCloseWithoutConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ c.EP.Close()
+
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
func TestTCPSegmentsSentIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -222,6 +241,38 @@ func TestTCPResetsSentIncrement(t *testing.T) {
}
}
+// TestTCPResetsSentNoICMP confirms that we don't get an ICMP
+// DstUnreachable packet when we try send a packet which is not part
+// of an active session.
+func TestTCPResetsSentNoICMP(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ stats := c.Stack().Stats()
+
+ // Send a SYN request for a closed port. This should elicit an RST
+ // but NOT an ICMPv4 DstUnreachable packet.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive whatever comes back.
+ b := c.GetPacket()
+ ipHdr := header.IPv4(b)
+ if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want {
+ t.Errorf("unexpected protocol, got = %d, want = %d", got, want)
+ }
+
+ // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded.
+ sent := stats.ICMP.V4PacketsSent
+ if got, want := sent.DstUnreachable.Value(), uint64(0); got != want {
+ t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want)
+ }
+}
+
// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
// a RST if an ACK is received on the listening socket for which there is no
// active handshake in progress and we are not using SYN cookies.
@@ -273,12 +324,12 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -291,16 +342,16 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
// Lower stackwide TIME_WAIT timeout so that the reservations
// are released instantly on Close.
tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
- t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err)
}
c.EP.Close()
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
SrcPort: context.TestPort,
@@ -330,8 +381,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst)))
}
@@ -414,8 +465,9 @@ func TestConnectResetAfterClose(t *testing.T) {
// Set TCPLinger to 3 seconds so that sockets are marked closed
// after 3 second in FIN_WAIT2 state.
tcpLingerTimeout := 3 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err)
+ opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
@@ -428,8 +480,8 @@ func TestConnectResetAfterClose(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -470,8 +522,8 @@ func TestConnectResetAfterClose(t *testing.T) {
// RST is always generated with sndNxt which if the FIN
// has been sent will be 1 higher than the sequence number
// of the FIN itself.
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -488,8 +540,9 @@ func TestCurrentConnectedIncrement(t *testing.T) {
// Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
// after 1 second in TIME_WAIT state.
tcpTimeWaitTimeout := 1 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
@@ -509,8 +562,8 @@ func TestCurrentConnectedIncrement(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -545,8 +598,8 @@ func TestCurrentConnectedIncrement(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -592,8 +645,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -613,8 +666,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -673,8 +726,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -725,135 +778,234 @@ func TestSimpleReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
}
-// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when
-// creating a new active IPv4 TCP socket. It should be present in the sent TCP
+// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when
+// creating a new active TCP socket. It should be present in the sent TCP
// SYN segment.
-func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
+func TestUserSuppliedMSSOnConnect(t *testing.T) {
const mtu = 5000
- const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
- tests := []struct {
- name string
- setMSS int
- expMSS uint16
+
+ ips := []struct {
+ name string
+ createEP func(*context.Context)
+ connectAddr tcpip.Address
+ checker func(*testing.T, *context.Context, uint16, int)
+ maxMSS uint16
}{
{
- "EqualToMaxMSS",
- maxMSS,
- maxMSS,
- },
- {
- "LessThanMTU",
- maxMSS - 1,
- maxMSS - 1,
+ name: "IPv4",
+ createEP: func(c *context.Context) {
+ c.Create(-1)
+ },
+ connectAddr: context.TestAddr,
+ checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
+ },
+ maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
},
{
- "GreaterThanMTU",
- maxMSS + 1,
- maxMSS,
+ name: "IPv6",
+ createEP: func(c *context.Context) {
+ c.CreateV6Endpoint(true)
+ },
+ connectAddr: context.TestV6Addr,
+ checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
+ },
+ maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
+ for _, ip := range ips {
+ t.Run(ip.name, func(t *testing.T) {
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ name: "EqualToMaxMSS",
+ setMSS: ip.maxMSS,
+ expMSS: ip.maxMSS,
+ },
+ {
+ name: "LessThanMaxMSS",
+ setMSS: ip.maxMSS - 1,
+ expMSS: ip.maxMSS - 1,
+ },
+ {
+ name: "GreaterThanMaxMSS",
+ setMSS: ip.maxMSS + 1,
+ expMSS: ip.maxMSS,
+ },
+ }
- c.Create(-1)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
- // Set the MSS socket option.
- if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil {
- t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err)
- }
+ ip.createEP(c)
- // Get expected window size.
- rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
- }
- ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+ // Set the MSS socket option.
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
+ }
- // Start connection attempt to IPv4 address.
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("unexpected return value from Connect: %s", err)
- }
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
- // Receive SYN packet with our user supplied MSS.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort}
+ if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Connect(%+v): %s", connectAddr, err)
+ }
+
+ // Receive SYN packet with our user supplied MSS.
+ ip.checker(t, c, test.expMSS, ws)
+ })
+ }
})
}
}
-// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when
-// creating a new active IPv6 TCP socket. It should be present in the sent TCP
-// SYN segment.
-func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
- const mtu = 5000
- const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize
- tests := []struct {
- name string
- setMSS uint16
- expMSS uint16
+// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used
+// when completing the handshake for a new TCP connection from a TCP
+// listening socket. It should be present in the sent TCP SYN-ACK segment.
+func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
+ const (
+ nonSynCookieAccepts = 2
+ totalAccepts = 4
+ mtu = 5000
+ )
+
+ ips := []struct {
+ name string
+ createEP func(*context.Context)
+ sendPkt func(*context.Context, *context.Headers)
+ checker func(*testing.T, *context.Context, uint16, uint16)
+ maxMSS uint16
}{
{
- "EqualToMaxMSS",
- maxMSS,
- maxMSS,
- },
- {
- "LessThanMTU",
- maxMSS - 1,
- maxMSS - 1,
+ name: "IPv4",
+ createEP: func(c *context.Context) {
+ c.Create(-1)
+ },
+ sendPkt: func(c *context.Context, h *context.Headers) {
+ c.SendPacket(nil, h)
+ },
+ checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
+ },
+ maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
},
{
- "GreaterThanMTU",
- maxMSS + 1,
- maxMSS,
+ name: "IPv6",
+ createEP: func(c *context.Context) {
+ c.CreateV6Endpoint(false)
+ },
+ sendPkt: func(c *context.Context, h *context.Headers) {
+ c.SendV6Packet(nil, h)
+ },
+ checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
+ },
+ maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
+ for _, ip := range ips {
+ t.Run(ip.name, func(t *testing.T) {
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ name: "EqualToMaxMSS",
+ setMSS: ip.maxMSS,
+ expMSS: ip.maxMSS,
+ },
+ {
+ name: "LessThanMaxMSS",
+ setMSS: ip.maxMSS - 1,
+ expMSS: ip.maxMSS - 1,
+ },
+ {
+ name: "GreaterThanMaxMSS",
+ setMSS: ip.maxMSS + 1,
+ expMSS: ip.maxMSS,
+ },
+ }
- c.CreateV6Endpoint(true)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
- // Set the MSS socket option.
- if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
- t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err)
- }
+ ip.createEP(c)
- // Get expected window size.
- rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
- }
- ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+ // Set the SynRcvd threshold to force a syn cookie based accept to happen.
+ opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
- // Start connection attempt to IPv6 address.
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("unexpected return value from Connect: %s", err)
- }
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
+ }
- // Receive SYN packet with our user supplied MSS.
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ bindAddr := tcpip.FullAddress{Port: context.StackPort}
+ if err := c.EP.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s:", bindAddr, err)
+ }
+
+ if err := c.EP.Listen(totalAccepts); err != nil {
+ t.Fatalf("Listen(%d): %s:", totalAccepts, err)
+ }
+
+ // The first nonSynCookieAccepts packets sent will trigger a gorooutine
+ // based accept. The rest will trigger a cookie based accept.
+ for i := 0; i < totalAccepts; i++ {
+ // Send a SYN requests.
+ iss := seqnum.Value(i)
+ srcPort := context.TestPort + uint16(i)
+ ip.sendPkt(c, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ ip.checker(t, c, srcPort, test.expMSS)
+ }
+ })
+ }
})
}
}
-
func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -879,7 +1031,7 @@ func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
@@ -907,7 +1059,7 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
@@ -944,8 +1096,8 @@ func TestTCPAckBeforeAcceptV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
@@ -982,8 +1134,8 @@ func TestTCPAckBeforeAcceptV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
func TestSendRstOnListenerRxAckV4(t *testing.T) {
@@ -1011,7 +1163,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
func TestSendRstOnListenerRxAckV6(t *testing.T) {
@@ -1039,7 +1191,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
// TestListenShutdown tests for the listening endpoint replying with RST
@@ -1155,8 +1307,8 @@ func TestTOSV4(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790), // Acknum is initial sequence number + 1
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790), // Acknum is initial sequence number + 1
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
checker.TOS(tos, 0),
@@ -1204,8 +1356,8 @@ func TestTrafficClassV6(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
checker.TOS(tos, 0),
@@ -1232,7 +1384,9 @@ func TestConnectBindToDevice(t *testing.T) {
c.Create(-1)
bindToDevice := tcpip.BindToDeviceOption(test.device)
- c.EP.SetSockOpt(bindToDevice)
+ if err := c.EP.SetSockOpt(&bindToDevice); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", bindToDevice, bindToDevice, err)
+ }
// Start connection attempt.
waitEntry, _ := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
@@ -1276,68 +1430,91 @@ func TestConnectBindToDevice(t *testing.T) {
}
}
-func TestRstOnSynSent(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
+func TestSynSent(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ reset bool
+ }{
+ {"RstOnSynSent", true},
+ {"CloseOnSynSent", false},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
- // Create an endpoint, don't handshake because we want to interfere with the
- // handshake process.
- c.Create(-1)
+ // Create an endpoint, don't handshake because we want to interfere with the
+ // handshake process.
+ c.Create(-1)
- // Start connection attempt.
- waitEntry, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventOut)
- defer c.WQ.EventUnregister(&waitEntry)
+ // Start connection attempt.
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
- addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
- if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
- t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted)
- }
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted)
+ }
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
- // Ensure that we've reached SynSent state
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
- // Send a packet with a proper ACK and a RST flag to cause the socket
- // to Error and close out
- iss := seqnum.Value(789)
- rcvWnd := seqnum.Size(30000)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagRst | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: nil,
- })
+ if test.reset {
+ // Send a packet with a proper ACK and a RST flag to cause the socket
+ // to error and close out.
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagRst | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+ } else {
+ c.EP.Close()
+ }
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatal("timed out waiting for packet to arrive")
- }
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for packet to arrive")
+ }
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
- }
+ if test.reset {
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
+ }
+ } else {
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted)
+ }
+ }
- // Due to the RST the endpoint should be in an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+
+ // Due to the RST the endpoint should be in an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ })
}
}
@@ -1370,8 +1547,8 @@ func TestOutOfOrderReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1421,8 +1598,8 @@ func TestOutOfOrderReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1432,8 +1609,8 @@ func TestOutOfOrderFlood(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- // Create a new connection with initial window size of 10.
- c.CreateConnected(789, 30000, 10)
+ rcvBufSz := math.MaxUint16
+ c.CreateConnected(789, 30000, rcvBufSz)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
@@ -1454,8 +1631,8 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1475,8 +1652,8 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1495,8 +1672,8 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(793),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(793),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1537,8 +1714,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1552,7 +1729,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
// We shouldn't consume a sequence number on RST.
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
@@ -1606,8 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1620,7 +1797,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
))
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
@@ -1639,7 +1816,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
// RST is always generated with sndNxt which if the FIN
// has been sent will be 1 higher than the sequence
// number of the FIN itself.
- checker.SeqNum(uint32(c.IRS)+2),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
@@ -1685,7 +1862,8 @@ func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, 10)
+ const rcvBufSz = 10
+ c.CreateConnected(789, 30000, rcvBufSz)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1696,8 +1874,13 @@ func TestFullWindowReceive(t *testing.T) {
t.Fatalf("Read failed: %s", err)
}
- // Fill up the window.
- data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
+ // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
+ // the provided buffer value by tcp.SegOverheadFactor to calculate the actual
+ // receive buffer size.
+ data := make([]byte, tcp.SegOverheadFactor*rcvBufSz)
+ for i := range data {
+ data[i] = byte(i % 255)
+ }
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
@@ -1718,10 +1901,10 @@ func TestFullWindowReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(0),
+ checker.TCPWindow(0),
),
)
@@ -1744,10 +1927,10 @@ func TestFullWindowReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(10),
+ checker.TCPWindow(10),
),
)
}
@@ -1756,12 +1939,15 @@ func TestNoWindowShrinking(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- // Start off with a window size of 10, then shrink it to 5.
- c.CreateConnected(789, 30000, 10)
-
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
- }
+ // Start off with a certain receive buffer then cut it in half and verify that
+ // the right edge of the window does not shrink.
+ // NOTE: Netstack doubles the value specified here.
+ rcvBufSize := 65536
+ iss := seqnum.Value(789)
+ // Enable window scaling with a scale of zero from our end.
+ c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{
+ header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
+ })
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1770,14 +1956,15 @@ func TestNoWindowShrinking(t *testing.T) {
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
-
- // Send 3 bytes, check that the peer acknowledges them.
- data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
- c.SendPacket(data[:3], &context.Headers{
+ // Send a 1 byte payload so that we can record the current receive window.
+ // Send a payload of half the size of rcvBufSize.
+ seqNum := iss.Add(1)
+ payload := []byte{1}
+ c.SendPacket(payload, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: seqNum,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1789,46 +1976,93 @@ func TestNoWindowShrinking(t *testing.T) {
t.Fatalf("Timed out waiting for data to arrive")
}
- // Check that data is acknowledged, and that window doesn't go to zero
- // just yet because it was previously set to 10. It must go to 7 now.
- checker.IPv4(t, c.GetPacket(),
+ // Read the 1 byte payload we just sent.
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+ if got, want := payload, v; !bytes.Equal(got, want) {
+ t.Fatalf("got data: %v, want: %v", got, want)
+ }
+
+ seqNum = seqNum.Add(1)
+ // Verify that the ACK does not shrink the window.
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(793),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(7),
),
)
+ // Stash the initial window.
+ initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
+ initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd))
+ // Now shrink the receive buffer to half its original size.
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
+ }
- // Send 7 more bytes, check that the window fills up.
- c.SendPacket(data[3:], &context.Headers{
+ data := generateRandomPayload(t, rcvBufSize)
+ // Send a payload of half the size of rcvBufSize.
+ c.SendPacket(data[:rcvBufSize/2], &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 793,
+ SeqNum: seqNum,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
+ seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2))
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
+ // Verify that the ACK does not shrink the window.
+ pkt = c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
+ newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd))
+ if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) {
+ t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq)
}
+ // Send another payload of half the size of rcvBufSize. This should fill up the
+ // socket receive buffer and we should see a zero window.
+ c.SendPacket(data[rcvBufSize/2:], &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqNum,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2))
+
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(0),
+ checker.TCPWindow(0),
),
)
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
// Receive data and check it.
- read := make([]byte, 0, 10)
+ read := make([]byte, 0, rcvBufSize)
for len(read) < len(data) {
v, _, err := c.EP.Read(nil)
if err != nil {
@@ -1842,15 +2076,15 @@ func TestNoWindowShrinking(t *testing.T) {
t.Fatalf("got data = %v, want = %v", read, data)
}
- // Check that we get an ACK for the newly non-zero window, which is the
- // new size.
+ // Check that we get an ACK for the newly non-zero window, which is the new
+ // receive buffer size we set after the connection was established.
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(5),
+ checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale),
),
)
}
@@ -1875,8 +2109,8 @@ func TestSimpleSend(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -1917,8 +2151,8 @@ func TestZeroWindowSend(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -1939,8 +2173,8 @@ func TestZeroWindowSend(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -1979,16 +2213,16 @@ func TestScaledWindowConnect(t *testing.T) {
t.Fatalf("Write failed: %s", err)
}
- // Check that data is received, and that advertised window is 0xbfff,
+ // Check that data is received, and that advertised window is 0x5fff,
// that is, that it is scaled.
b := c.GetPacket()
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xbfff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0x5fff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2018,9 +2252,9 @@ func TestNonScaledWindowConnect(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xffff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0xffff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2054,19 +2288,20 @@ func TestScaledWindowAccept(t *testing.T) {
}
// Do 3-way handshake.
- c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2
+ c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2084,16 +2319,16 @@ func TestScaledWindowAccept(t *testing.T) {
t.Fatalf("Write failed: %s", err)
}
- // Check that data is received, and that advertised window is 0xbfff,
+ // Check that data is received, and that advertised window is 0x5fff,
// that is, that it is scaled.
b := c.GetPacket()
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xbfff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0x5fff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2135,12 +2370,12 @@ func TestNonScaledWindowAccept(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2165,9 +2400,9 @@ func TestNonScaledWindowAccept(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xffff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0xffff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2180,18 +2415,19 @@ func TestZeroScaledWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- // Set the window size such that a window scale of 4 will be used.
- const wnd = 65535 * 10
- const ws = uint32(4)
- c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
+ // Set the buffer size such that a window scale of 5 will be used.
+ const bufSz = 65535 * 10
+ const ws = uint32(5)
+ c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
// Write chunks of 50000 bytes.
- remain := wnd
+ remain := 0
sent := 0
data := make([]byte, 50000)
- for remain > len(data) {
+ // Keep writing till the window drops below len(data).
+ for {
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
@@ -2201,21 +2437,25 @@ func TestZeroScaledWindowReceive(t *testing.T) {
RcvWnd: 30000,
})
sent += len(data)
- remain -= len(data)
- checker.IPv4(t, c.GetPacket(),
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(remain>>ws)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
+ // Don't reduce window to zero here.
+ if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) {
+ remain = wnd << ws
+ break
+ }
}
// Make the window non-zero, but the scaled window zero.
- if remain >= 16 {
+ for remain >= 16 {
data = data[:remain-15]
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
@@ -2226,22 +2466,35 @@ func TestZeroScaledWindowReceive(t *testing.T) {
RcvWnd: 30000,
})
sent += len(data)
- remain -= len(data)
- checker.IPv4(t, c.GetPacket(),
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(0),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
+ // Since the receive buffer is split between window advertisement and
+ // application data buffer the window does not always reflect the space
+ // available and actual space available can be a bit more than what is
+ // advertised in the window.
+ wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize())
+ if wnd == 0 {
+ break
+ }
+ remain = wnd << ws
}
- // Read at least 1MSS of data. An ack should be sent in response to that.
+ // Read at least 2MSS of data. An ack should be sent in response to that.
+ // Since buffer space is now split in half between window and application
+ // data we need to read more than 1 MSS(65536) of data for a non-zero window
+ // update to be sent. For 1MSS worth of window to be available we need to
+ // read at least 128KB. Since our segments above were 50KB each it means
+ // we need to read at 3 packets.
sz := 0
- for sz < defaultMTU {
+ for sz < defaultMTU*2 {
v, _, err := c.EP.Read(nil)
if err != nil {
t.Fatalf("Read failed: %s", err)
@@ -2253,9 +2506,9 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(sz>>ws)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -2322,8 +2575,8 @@ func TestSegmentMerging(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize+1),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+uint32(i)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2345,8 +2598,8 @@ func TestSegmentMerging(t *testing.T) {
checker.PayloadLen(len(allData)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+11),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+11),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2393,8 +2646,8 @@ func TestDelay(t *testing.T) {
checker.PayloadLen(len(want)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2440,8 +2693,8 @@ func TestUndelay(t *testing.T) {
checker.PayloadLen(len(allData[0])+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2463,8 +2716,8 @@ func TestUndelay(t *testing.T) {
checker.PayloadLen(len(allData[1])+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2525,8 +2778,8 @@ func TestMSSNotDelayed(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2577,8 +2830,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2698,12 +2951,12 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2725,8 +2978,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ opt := tcpip.TCPSynRcvdCountThresholdOption(0)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
// Create EP and start listening.
@@ -2753,12 +3007,12 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -2819,7 +3073,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- const wndScale = 2
+ const wndScale = 3
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
}
@@ -2854,7 +3108,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagSyn),
checker.SrcPort(tcpHdr.SourcePort()),
- checker.SeqNum(tcpHdr.SequenceNumber()),
+ checker.TCPSeqNum(tcpHdr.SequenceNumber()),
checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
),
)
@@ -2875,16 +3129,16 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
),
)
// Wait for connection to be established.
select {
case <-ch:
- if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("GetSockOpt failed: %s", err)
+ if err := c.EP.LastError(); err != nil {
+ t.Fatalf("Connect failed: %s", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -3004,8 +3258,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
defer c.Cleanup()
const numRetries = 2
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil {
- t.Fatalf("could not set protocol option MaxRetries.\n")
+ opt := tcpip.TCPMaxRetriesOption(numRetries)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
@@ -3064,8 +3319,9 @@ func TestMaxRTO(t *testing.T) {
defer c.Cleanup()
rto := 1 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err)
+ opt := tcpip.TCPMaxRTOOption(rto)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
@@ -3095,6 +3351,63 @@ func TestMaxRTO(t *testing.T) {
}
}
+// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
+// unique on retransmits.
+func TestRetransmitIPv4IDUniqueness(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ size int
+ }{
+ {"1Byte", 1},
+ {"512Bytes", 512},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ // Disabling PMTU discovery causes all packets sent from this socket to
+ // have DF=0. This needs to be done because the IPv4 ID uniqueness
+ // applies only to non-atomic IPv4 datagrams as defined in RFC 6864
+ // Section 4, and datagrams with DF=0 are non-atomic.
+ if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil {
+ t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}}
+ // Expect two retransmitted packets, and that all packets received have
+ // unique IPv4 ID values.
+ for i := 0; i <= 2; i++ {
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ id := header.IPv4(pkt).ID()
+ if _, exists := idSet[id]; exists {
+ t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id)
+ }
+ idSet[id] = struct{}{}
+ }
+ })
+ }
+}
+
func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -3110,8 +3423,8 @@ func TestFinImmediately(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3131,8 +3444,8 @@ func TestFinImmediately(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3153,8 +3466,8 @@ func TestFinRetransmit(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3164,8 +3477,8 @@ func TestFinRetransmit(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3185,8 +3498,8 @@ func TestFinRetransmit(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3209,8 +3522,8 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3234,8 +3547,8 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3256,8 +3569,8 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3284,8 +3597,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3303,8 +3616,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3323,8 +3636,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3344,8 +3657,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3368,8 +3681,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3393,8 +3706,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3409,8 +3722,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3430,8 +3743,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3455,8 +3768,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3476,8 +3789,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3491,8 +3804,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3507,8 +3820,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3599,8 +3912,8 @@ func scaledSendWindow(t *testing.T, scale uint8) {
checker.PayloadLen((1<<scale)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3738,7 +4051,7 @@ func TestReceivedSegmentQueuing(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3765,8 +4078,9 @@ func TestReadAfterClosedState(t *testing.T) {
// Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
// after 1 second in TIME_WAIT state.
tcpTimeWaitTimeout := 1 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
@@ -3788,8 +4102,8 @@ func TestReadAfterClosedState(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3813,8 +4127,8 @@ func TestReadAfterClosedState(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(uint32(791+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(uint32(791+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3986,8 +4300,8 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func TestDefaultBufferSizes(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
// Check the default values.
@@ -4005,11 +4319,15 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default send buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{
- Min: 1,
- Default: tcp.DefaultSendBufferSize * 2,
- Max: tcp.DefaultSendBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPSendBufferSizeRangeOption{
+ Min: 1,
+ Default: tcp.DefaultSendBufferSize * 2,
+ Max: tcp.DefaultSendBufferSize * 20,
+ }
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
ep.Close()
@@ -4022,11 +4340,15 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default receive buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{
- Min: 1,
- Default: tcp.DefaultReceiveBufferSize * 3,
- Max: tcp.DefaultReceiveBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ {
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize * 3,
+ Max: tcp.DefaultReceiveBufferSize * 30,
+ }
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
ep.Close()
@@ -4041,8 +4363,8 @@ func TestDefaultBufferSizes(t *testing.T) {
func TestMinMaxBufferSizes(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
// Check the default values.
@@ -4053,22 +4375,28 @@ func TestMinMaxBufferSizes(t *testing.T) {
defer ep.Close()
// Change the min/max values for send/receive
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
- // Set values below the min.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
+ // Set values below the min/2.
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err)
}
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil {
t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err)
}
@@ -4079,19 +4407,21 @@ func TestMinMaxBufferSizes(t *testing.T) {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
}
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
+ // Values above max are capped at max and then doubled.
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2)
if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err)
}
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
+ // Values above max are capped at max and then doubled.
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2)
}
func TestBindToDeviceOption(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}})
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
@@ -4124,16 +4454,15 @@ func TestBindToDeviceOption(t *testing.T) {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("SetSockOpt(%#v) got %v, want %v", bindToDevice, gotErr, wantErr)
+ if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
}
}
bindToDevice := tcpip.BindToDeviceOption(88888)
if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt got %s, want %v", err, nil)
- }
- if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %d, want %d", got, want)
+ t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err)
+ } else if bindToDevice != testAction.getBindToDevice {
+ t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice)
}
})
}
@@ -4141,11 +4470,11 @@ func TestBindToDeviceOption(t *testing.T) {
func makeStack() (*stack.Stack, *tcpip.Error) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{
- ipv4.NewProtocol(),
- ipv6.NewProtocol(),
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
id := loopback.New()
@@ -4214,7 +4543,7 @@ func TestSelfConnect(t *testing.T) {
}
<-notifyCh
- if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil {
+ if err := ep.LastError(); err != nil {
t.Fatalf("Connect failed: %s", err)
}
@@ -4428,8 +4757,8 @@ func TestPathMTUDiscovery(t *testing.T) {
checker.PayloadLen(size+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(seqNum),
- checker.AckNum(790),
+ checker.TCPSeqNum(seqNum),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -4520,8 +4849,8 @@ func TestStackSetCongestionControl(t *testing.T) {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
- t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err {
+ t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err)
}
var cc tcpip.CongestionControlOption
@@ -4553,12 +4882,12 @@ func TestStackAvailableCongestionControl(t *testing.T) {
s := c.Stack()
// Query permitted congestion control algorithms.
- var aCC tcpip.AvailableCongestionControlOption
+ var aCC tcpip.TCPAvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
}
- if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want)
}
}
@@ -4569,18 +4898,18 @@ func TestStackSetAvailableCongestionControl(t *testing.T) {
s := c.Stack()
// Setting AvailableCongestionControlOption should fail.
- aCC := tcpip.AvailableCongestionControlOption("xyz")
+ aCC := tcpip.TCPAvailableCongestionControlOption("xyz")
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC)
+ t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC)
}
// Verify that we still get the expected list of congestion control options.
- var cc tcpip.AvailableCongestionControlOption
+ var cc tcpip.TCPAvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
+ t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err)
}
- if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want)
}
}
@@ -4609,20 +4938,20 @@ func TestEndpointSetCongestionControl(t *testing.T) {
var oldCC tcpip.CongestionControlOption
if err := c.EP.GetSockOpt(&oldCC); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %s", &oldCC, err)
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err)
}
if connected {
c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil)
}
- if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
- t.Fatalf("c.EP.SetSockOpt(%v) = %s, want %s", tc.cc, err, tc.err)
+ if err := c.EP.SetSockOpt(&tc.cc); err != tc.err {
+ t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err)
}
var cc tcpip.CongestionControlOption
if err := c.EP.GetSockOpt(&cc); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %s", &cc, err)
+ t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err)
}
got, want := cc, oldCC
@@ -4634,7 +4963,7 @@ func TestEndpointSetCongestionControl(t *testing.T) {
want = tc.cc
}
if got != want {
- t.Fatalf("got congestion control: %v, want: %v", got, want)
+ t.Fatalf("got congestion control = %+v, want = %+v", got, want)
}
})
}
@@ -4644,8 +4973,8 @@ func TestEndpointSetCongestionControl(t *testing.T) {
func enableCUBIC(t *testing.T, c *context.Context) {
t.Helper()
opt := tcpip.CongestionControlOption("cubic")
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -4655,11 +4984,23 @@ func TestKeepalive(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ const keepAliveIdle = 100 * time.Millisecond
const keepAliveInterval = 3 * time.Second
- c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
- c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
+ keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle)
+ if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err)
+ }
+ keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval)
+ if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err)
+ }
c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5)
- c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
+ if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil {
+ t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil {
+ t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err)
+ }
// 5 unacked keepalives are sent. ACK each one, and check that the
// connection stays alive after 5.
@@ -4668,8 +5009,8 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(uint32(790)),
+ checker.TCPSeqNum(uint32(c.IRS)),
+ checker.TCPAckNum(uint32(790)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -4702,8 +5043,8 @@ func TestKeepalive(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -4714,8 +5055,8 @@ func TestKeepalive(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
),
)
@@ -4740,8 +5081,8 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next-1)),
- checker.AckNum(uint32(790)),
+ checker.TCPSeqNum(uint32(next-1)),
+ checker.TCPAckNum(uint32(790)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -4767,8 +5108,8 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(0)),
+ checker.TCPSeqNum(uint32(next)),
+ checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -4808,7 +5149,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
checker.SrcPort(context.StackPort),
checker.DstPort(srcPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
if synCookieInUse {
@@ -4852,7 +5193,7 @@ func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCoo
checker.SrcPort(context.StackPort),
checker.DstPort(srcPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
if synCookieInUse {
@@ -4925,12 +5266,12 @@ func TestListenBacklogFull(t *testing.T) {
defer c.WQ.EventUnregister(&we)
for i := 0; i < listenBacklog; i++ {
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -4942,7 +5283,7 @@ func TestListenBacklogFull(t *testing.T) {
}
// Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != tcpip.ErrWouldBlock {
select {
case <-ch:
@@ -4954,12 +5295,12 @@ func TestListenBacklogFull(t *testing.T) {
// Now a new handshake must succeed.
executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */)
- newEP, _, err := c.EP.Accept()
+ newEP, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -4984,6 +5325,8 @@ func TestListenBacklogFull(t *testing.T) {
func TestListenNoAcceptNonUnicastV4(t *testing.T) {
multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+ subnet := context.StackAddrWithPrefix.Subnet()
+ subnetBroadcastAddr := subnet.Broadcast()
tests := []struct {
name string
@@ -4991,53 +5334,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
dstAddr tcpip.Address
}{
{
- "SourceUnspecified",
- header.IPv4Any,
- context.StackAddr,
+ name: "SourceUnspecified",
+ srcAddr: header.IPv4Any,
+ dstAddr: context.StackAddr,
},
{
- "SourceBroadcast",
- header.IPv4Broadcast,
- context.StackAddr,
+ name: "SourceBroadcast",
+ srcAddr: header.IPv4Broadcast,
+ dstAddr: context.StackAddr,
},
{
- "SourceOurMulticast",
- multicastAddr,
- context.StackAddr,
+ name: "SourceOurMulticast",
+ srcAddr: multicastAddr,
+ dstAddr: context.StackAddr,
},
{
- "SourceOtherMulticast",
- otherMulticastAddr,
- context.StackAddr,
+ name: "SourceOtherMulticast",
+ srcAddr: otherMulticastAddr,
+ dstAddr: context.StackAddr,
},
{
- "DestUnspecified",
- context.TestAddr,
- header.IPv4Any,
+ name: "DestUnspecified",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Any,
},
{
- "DestBroadcast",
- context.TestAddr,
- header.IPv4Broadcast,
+ name: "DestBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Broadcast,
},
{
- "DestOurMulticast",
- context.TestAddr,
- multicastAddr,
+ name: "DestOurMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: multicastAddr,
},
{
- "DestOtherMulticast",
- context.TestAddr,
- otherMulticastAddr,
+ name: "DestOtherMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: otherMulticastAddr,
+ },
+ {
+ name: "SrcSubnetBroadcast",
+ srcAddr: subnetBroadcastAddr,
+ dstAddr: context.StackAddr,
+ },
+ {
+ name: "DestSubnetBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: subnetBroadcastAddr,
},
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5078,7 +5427,7 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
+ checker.TCPAckNum(uint32(irs)+1)))
})
}
}
@@ -5086,8 +5435,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
// non unicast IPv6 address are not accepted.
func TestListenNoAcceptNonUnicastV6(t *testing.T) {
- multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
- otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
+ multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
+ otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
tests := []struct {
name string
@@ -5137,11 +5486,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5182,7 +5527,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
+ checker.TCPAckNum(uint32(irs)+1)))
})
}
}
@@ -5230,7 +5575,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
@@ -5266,12 +5611,12 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- newEP, _, err := c.EP.Accept()
+ newEP, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5295,8 +5640,9 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err)
+ opt := tcpip.TCPSynRcvdCountThresholdOption(1)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
// Create TCP endpoint.
@@ -5342,12 +5688,12 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5358,7 +5704,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
// Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != tcpip.ErrWouldBlock {
select {
case <-ch:
@@ -5407,7 +5753,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
@@ -5428,8 +5774,8 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.AckNum(uint32(irs) + 1),
- checker.SeqNum(uint32(iss + 1)),
+ checker.TCPAckNum(uint32(irs) + 1),
+ checker.TCPSeqNum(uint32(iss + 1)),
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
@@ -5447,7 +5793,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
RcvWnd: 30000,
})
- newEP, _, err := c.EP.Accept()
+ newEP, _, err := c.EP.Accept(nil)
if err != nil && err != tcpip.ErrWouldBlock {
t.Fatalf("Accept failed: %s", err)
@@ -5462,7 +5808,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
// Wait for connection to be established.
select {
case <-ch:
- newEP, _, err = c.EP.Accept()
+ newEP, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5520,12 +5866,12 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
defer c.WQ.EventUnregister(&we)
// Verify that there is only one acceptable connection at this point.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5590,12 +5936,12 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
defer c.WQ.EventUnregister(&we)
// Now check that there is one acceptable connections.
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ _, _, err = c.EP.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5643,12 +5989,12 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- aep, _, err := ep.Accept()
+ aep, _, err := ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- aep, _, err = ep.Accept()
+ aep, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -5696,13 +6042,19 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// the segment queue holding unprocessed packets is limited to 500.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
// Enable auto-tuning.
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPModerateReceiveBufferOption(true)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
}
// Change the expected window scale to match the value needed for the
// maximum buffer size defined above.
@@ -5721,16 +6073,14 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
time.Sleep(latency)
rawEP.SendPacketWithTS([]byte{1}, tsVal)
- // Verify that the ACK has the expected window.
- wantRcvWnd := receiveBufferSize
- wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale))
- rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1))
+ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
+ rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize()
time.Sleep(25 * time.Millisecond)
// Allocate a large enough payload for the test.
- b := make([]byte, int(receiveBufferSize)*2)
- offset := 0
- payloadSize := receiveBufferSize - 1
+ payloadSize := receiveBufferSize * 2
+ b := make([]byte, int(payloadSize))
+
worker := (c.EP).(interface {
StopWork()
ResumeWork()
@@ -5739,11 +6089,15 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Stop the worker goroutine.
worker.StopWork()
- start := offset
- end := offset + payloadSize
+ start := 0
+ end := payloadSize / 2
packetsSent := 0
for ; start < end; start += mss {
- rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+ packetEnd := start + mss
+ if start+mss > end {
+ packetEnd = end
+ }
+ rawEP.SendPacketWithTS(b[start:packetEnd], tsVal)
packetsSent++
}
@@ -5751,29 +6105,20 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// are waiting to be read.
worker.ResumeWork()
- // Since we read no bytes the window should goto zero till the
- // application reads some of the data.
- // Discard all intermediate acks except the last one.
- if packetsSent > 100 {
- for i := 0; i < (packetsSent / 100); i++ {
- _ = c.GetPacket()
- }
+ // Since we sent almost the full receive buffer worth of data (some may have
+ // been dropped due to segment overheads), we should get a zero window back.
+ pkt = c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(pkt).Payload())
+ gotRcvWnd := tcpHdr.WindowSize()
+ wantAckNum := tcpHdr.AckNumber()
+ if got, want := int(gotRcvWnd), 0; got != want {
+ t.Fatalf("got rcvWnd: %d, want: %d", got, want)
}
- rawEP.VerifyACKRcvWnd(0)
time.Sleep(25 * time.Millisecond)
- // Verify that sending more data when window is closed is dropped and
- // not acked.
+ // Verify that sending more data when receiveBuffer is exhausted.
rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
- // Verify that the stack sends us back an ACK with the sequence number
- // of the last packet sent indicating it was dropped.
- p := c.GetPacket()
- checker.IPv4(t, p, checker.TCP(
- checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
- checker.Window(0),
- ))
-
// Now read all the data from the endpoint and verify that advertised
// window increases to the full available buffer size.
for {
@@ -5786,23 +6131,26 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Verify that we receive a non-zero window update ACK. When running
// under thread santizer this test can end up sending more than 1
// ack, 1 for the non-zero window
- p = c.GetPacket()
+ p := c.GetPacket()
checker.IPv4(t, p, checker.TCP(
- checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
+ checker.TCPAckNum(uint32(wantAckNum)),
func(t *testing.T, h header.Transport) {
tcp, ok := h.(header.TCP)
if !ok {
return
}
- if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) {
- t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w)
+ // We use 10% here as the error margin upwards as the initial window we
+ // got was afer 1 segment was already in the receive buffer queue.
+ tolerance := 1.1
+ if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) {
+ t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance))
}
},
))
}
-// This test verifies that the auto tuning does not grow the receive buffer if
-// the application is not reading the data actively.
+// This test verifies that the advertised window is auto-tuned up as the
+// application is reading the data that is being received.
func TestReceiveBufferAutoTuning(t *testing.T) {
const mtu = 1500
const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
@@ -5812,26 +6160,33 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// Enable Auto-tuning.
stk := c.Stack()
- // Set lower limits for auto-tuning tests. This is required because the
- // test stops the worker which can cause packets to be dropped because
- // the segment queue holding unprocessed packets is limited to 300.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
}
// Enable auto-tuning.
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ {
+ opt := tcpip.TCPModerateReceiveBufferOption(true)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
}
// Change the expected window scale to match the value needed for the
// maximum buffer size used by stack.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
-
- wantRcvWnd := receiveBufferSize
+ tsVal := uint32(rawEP.TSVal)
+ rawEP.NextSeqNum--
+ rawEP.SendPacketWithTS(nil, tsVal)
+ rawEP.NextSeqNum++
+ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
+ curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
scaleRcvWnd := func(rcvWnd int) uint16 {
return uint16(rcvWnd >> uint16(c.WindowScale))
}
@@ -5848,14 +6203,8 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
StopWork()
ResumeWork()
})
- tsVal := rawEP.TSVal
- // We are going to do our own computation of what the moderated receive
- // buffer should be based on sent/copied data per RTT and verify that
- // the advertised window by the stack matches our calculations.
- prevCopied := 0
- done := false
latency := 1 * time.Millisecond
- for i := 0; !done; i++ {
+ for i := 0; i < 5; i++ {
tsVal++
// Stop the worker goroutine.
@@ -5877,15 +6226,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// Give 1ms for the worker to process the packets.
time.Sleep(1 * time.Millisecond)
- // Verify that the advertised window on the ACK is reduced by
- // the total bytes sent.
- expectedWnd := wantRcvWnd - totalSent
- if packetsSent > 100 {
- for i := 0; i < (packetsSent / 100); i++ {
- _ = c.GetPacket()
+ lastACK := c.GetPacket()
+ // Discard any intermediate ACKs and only check the last ACK we get in a
+ // short time period of few ms.
+ for {
+ time.Sleep(1 * time.Millisecond)
+ pkt := c.GetPacketNonBlocking()
+ if pkt == nil {
+ break
}
+ lastACK = pkt
+ }
+ if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want {
+ t.Fatalf("advertised window got: %d, want <= %d", got, want)
}
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd))
// Now read all the data from the endpoint and invoke the
// moderation API to allow for receive buffer auto-tuning
@@ -5910,35 +6264,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
rawEP.NextSeqNum++
-
if i == 0 {
// In the first iteration the receiver based RTT is not
// yet known as a result the moderation code should not
// increase the advertised window.
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd))
- prevCopied = totalCopied
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd))
} else {
- rttCopied := totalCopied
- if i == 1 {
- // The moderation code accumulates copied bytes till
- // RTT is established. So add in the bytes sent in
- // the first iteration to the total bytes for this
- // RTT.
- rttCopied += prevCopied
- // Now reset it to the initial value used by the
- // auto tuning logic.
- prevCopied = tcp.InitialCwnd * mss * 2
- }
- newWnd := rttCopied<<1 + 16*mss
- grow := (newWnd * (rttCopied - prevCopied)) / prevCopied
- newWnd += (grow << 1)
- if newWnd > maxReceiveBufferSize {
- newWnd = maxReceiveBufferSize
- done = true
+ pkt := c.GetPacket()
+ curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
+ // If thew new current window is close maxReceiveBufferSize then terminate
+ // the loop. This can happen before all iterations are done due to timing
+ // differences when running the test.
+ if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 {
+ break
}
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd))
- wantRcvWnd = newWnd
- prevCopied = rttCopied
// Increase the latency after first two iterations to
// establish a low RTT value in the receiver since it
// only tracks the lowest value. This ensures that when
@@ -5951,6 +6290,12 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
offset += payloadSize
payloadSize *= 2
}
+ // Check that at the end of our iterations the receive window grew close to the maximum
+ // permissible size of maxReceiveBufferSize/2
+ if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want {
+ t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want)
+ }
+
}
func TestDelayEnabled(t *testing.T) {
@@ -5959,7 +6304,7 @@ func TestDelayEnabled(t *testing.T) {
checkDelayOption(t, c, false, false) // Delay is disabled by default.
for _, v := range []struct {
- delayEnabled tcp.DelayEnabled
+ delayEnabled tcpip.TCPDelayEnabled
wantDelayOption bool
}{
{delayEnabled: false, wantDelayOption: false},
@@ -5967,17 +6312,17 @@ func TestDelayEnabled(t *testing.T) {
} {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil {
- t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err)
}
checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
}
}
-func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) {
+func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) {
t.Helper()
- var gotDelayEnabled tcp.DelayEnabled
+ var gotDelayEnabled tcpip.TCPDelayEnabled
if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err)
}
@@ -6009,24 +6354,27 @@ func TestTCPLingerTimeout(t *testing.T) {
tcpLingerTimeout time.Duration
want time.Duration
}{
- {"NegativeLingerTimeout", -123123, 0},
- {"ZeroLingerTimeout", 0, 0},
+ {"NegativeLingerTimeout", -123123, -1},
+ // Zero is treated same as the stack's default TCP_LINGER2 timeout.
+ {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout},
{"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
// Values > stack's TCPLingerTimeout are capped to the stack's
// value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
- {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second},
+ {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil {
- t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err)
+ v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)
+ if err := c.EP.SetSockOpt(&v); err != nil {
+ t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err)
}
- var v tcpip.TCPLingerTimeoutOption
+
+ v = 0
if err := c.EP.GetSockOpt(&v); err != nil {
- t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err)
+ t.Fatalf("GetSockOpt(&%T) = %s", v, err)
}
if got, want := time.Duration(v), tc.want; got != want {
- t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want)
+ t.Fatalf("got linger timeout = %s, want = %s", got, want)
}
})
}
@@ -6080,12 +6428,12 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6099,8 +6447,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6117,8 +6465,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Now send a RST and this should be ignored and not
@@ -6146,8 +6494,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
}
@@ -6199,12 +6547,12 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6218,8 +6566,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6236,8 +6584,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Out of order ACK should generate an immediate ACK in
@@ -6253,8 +6601,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
}
@@ -6306,12 +6654,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6325,8 +6673,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6343,8 +6691,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Send a SYN request w/ sequence number lower than
@@ -6389,12 +6737,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
c.SendPacket(nil, ackHeaders)
// Try to accept the connection.
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6412,8 +6760,9 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
// after 5 seconds in TIME_WAIT state.
tcpTimeWaitTimeout := 5 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err)
}
want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
@@ -6462,12 +6811,12 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6481,8 +6830,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6499,8 +6848,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
time.Sleep(2 * time.Second)
@@ -6514,8 +6863,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Sleep for 4 seconds so at this point we are 1 second past the
@@ -6543,8 +6892,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(ackHeaders.AckNum)),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst)))
if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
@@ -6562,8 +6911,9 @@ func TestTCPCloseWithData(t *testing.T) {
// Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
// after 5 seconds in TIME_WAIT state.
tcpTimeWaitTimeout := 5 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err)
}
wq := &waiter.Queue{}
@@ -6611,12 +6961,12 @@ func TestTCPCloseWithData(t *testing.T) {
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
t.Fatalf("Accept failed: %s", err)
}
@@ -6642,8 +6992,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Now write a few bytes and then close the endpoint.
@@ -6661,8 +7011,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -6676,8 +7026,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
- checker.AckNum(uint32(iss+2)),
+ checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))),
+ checker.TCPAckNum(uint32(iss+2)),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
// First send a partial ACK.
@@ -6722,8 +7072,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(ackHeaders.AckNum)),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst)))
}
@@ -6743,7 +7093,10 @@ func TestTCPUserTimeout(t *testing.T) {
// expired.
initRTO := 1 * time.Second
userTimeout := initRTO / 2
- c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+ v := tcpip.TCPUserTimeoutOption(userTimeout)
+ if err := c.EP.SetSockOpt(&v); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err)
+ }
// Send some data and wait before ACKing it.
view := buffer.NewView(3)
@@ -6756,8 +7109,8 @@ func TestTCPUserTimeout(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -6791,8 +7144,8 @@ func TestTCPUserTimeout(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(0)),
+ checker.TCPSeqNum(uint32(next)),
+ checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -6817,18 +7170,31 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+ const keepAliveIdle = 100 * time.Millisecond
const keepAliveInterval = 3 * time.Second
- c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
- c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
- c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10)
- c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
+ keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle)
+ if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err)
+ }
+ keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval)
+ if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err)
+ }
+ if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil {
+ t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil {
+ t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err)
+ }
// Set userTimeout to be the duration to be 1 keepalive
// probes. Which means that after the first probe is sent
// the second one should cause the connection to be
// closed due to userTimeout being hit.
- userTimeout := 1 * keepAliveInterval
- c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+ userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval)
+ if err := c.EP.SetSockOpt(&userTimeout); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err)
+ }
// Check that the connection is still alive.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
@@ -6840,8 +7206,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(uint32(790)),
+ checker.TCPSeqNum(uint32(c.IRS)),
+ checker.TCPAckNum(uint32(790)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -6866,8 +7232,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(0)),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -6883,9 +7249,9 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
}
}
-func TestIncreaseWindowOnReceive(t *testing.T) {
+func TestIncreaseWindowOnRead(t *testing.T) {
// This test ensures that the endpoint sends an ack,
- // after recv() when the window grows to more than 1 MSS.
+ // after read() when the window grows by more than 1 MSS.
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -6894,10 +7260,9 @@ func TestIncreaseWindowOnReceive(t *testing.T) {
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
- remain := rcvBuf
+ remain := rcvBuf * 2
sent := 0
data := make([]byte, defaultMTU/2)
- lastWnd := uint16(0)
for remain > len(data) {
c.SendPacket(data, &context.Headers{
@@ -6910,46 +7275,43 @@ func TestIncreaseWindowOnReceive(t *testing.T) {
})
sent += len(data)
remain -= len(data)
-
- lastWnd = uint16(remain)
- if remain > 0xffff {
- lastWnd = 0xffff
- }
- checker.IPv4(t, c.GetPacket(),
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(lastWnd),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
+ // Break once the window drops below defaultMTU/2
+ if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 {
+ break
+ }
}
- if lastWnd == 0xffff || lastWnd == 0 {
- t.Fatalf("expected small, non-zero window: %d", lastWnd)
- }
-
- // We now have < 1 MSS in the buffer space. Read the data! An
- // ack should be sent in response to that. The window was not
- // zero, but it grew to larger than MSS.
- if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
- if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %s", err)
+ // We now have < 1 MSS in the buffer space. Read at least > 2 MSS
+ // worth of data as receive buffer space
+ read := 0
+ // defaultMTU is a good enough estimate for the MSS used for this
+ // connection.
+ for read < defaultMTU*2 {
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+ read += len(v)
}
- // After reading two packets, we surely crossed MSS. See the ack:
+ // After reading > MSS worth of data, we surely crossed MSS. See the ack:
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(0xffff)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindow(uint16(0xffff)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -6966,10 +7328,9 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
- remain := rcvBuf
+ remain := rcvBuf * 2
sent := 0
data := make([]byte, defaultMTU/2)
- lastWnd := uint16(0)
for remain > len(data) {
c.SendPacket(data, &context.Headers{
@@ -6983,38 +7344,29 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
sent += len(data)
remain -= len(data)
- lastWnd = uint16(remain)
- if remain > 0xffff {
- lastWnd = 0xffff
- }
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(lastWnd),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindowLessThanEq(0xffff),
checker.TCPFlags(header.TCPFlagAck),
),
)
}
- if lastWnd == 0xffff || lastWnd == 0 {
- t.Fatalf("expected small, non-zero window: %d", lastWnd)
- }
-
// Increasing the buffer from should generate an ACK,
// since window grew from small value to larger equal MSS
c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
- // After reading two packets, we surely crossed MSS. See the ack:
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(0xffff)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindow(uint16(0xffff)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -7035,14 +7387,15 @@ func TestTCPDeferAccept(t *testing.T) {
}
const tcpDeferAccept = 1 * time.Second
- if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err)
+ tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept)
+ if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err)
}
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
+ if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Send data. This should result in an acceptable endpoint.
@@ -7058,14 +7411,14 @@ func TestTCPDeferAccept(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
// Give a bit of time for the socket to be delivered to the accept queue.
time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept()
+ aep, _, err := c.EP.Accept(nil)
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
}
aep.Close()
@@ -7073,8 +7426,8 @@ func TestTCPDeferAccept(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
func TestTCPDeferAcceptTimeout(t *testing.T) {
@@ -7092,14 +7445,15 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
}
const tcpDeferAccept = 1 * time.Second
- if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err)
+ tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept)
+ if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err)
}
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
+ if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock)
}
// Sleep for a little of the tcpDeferAccept timeout.
@@ -7110,7 +7464,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
+ checker.TCPAckNum(uint32(irs)+1)))
// Send data. This should result in an acceptable endpoint.
c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
@@ -7126,14 +7480,14 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
// Give sometime for the endpoint to be delivered to the accept queue.
time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept()
+ aep, _, err := c.EP.Accept(nil)
if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
+ t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
}
aep.Close()
@@ -7142,8 +7496,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
func TestResetDuringClose(t *testing.T) {
@@ -7168,8 +7522,8 @@ func TestResetDuringClose(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(irs.Add(1))),
- checker.AckNum(uint32(iss.Add(5)))))
+ checker.TCPSeqNum(uint32(irs.Add(1))),
+ checker.TCPAckNum(uint32(iss.Add(5)))))
// Close in a separate goroutine so that we can trigger
// a race with the RST we send below. This should not
@@ -7199,3 +7553,65 @@ func TestResetDuringClose(t *testing.T) {
wg.Wait()
}
+
+func TestStackTimeWaitReuse(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ s := c.Stack()
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err)
+ }
+ if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want {
+ t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
+ }
+}
+
+func TestSetStackTimeWaitReuse(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ s := c.Stack()
+ testCases := []struct {
+ v int
+ err *tcpip.Error
+ }{
+ {int(tcpip.TCPTimeWaitReuseDisabled), nil},
+ {int(tcpip.TCPTimeWaitReuseGlobal), nil},
+ {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil},
+ {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue},
+ {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue},
+ }
+
+ for _, tc := range testCases {
+ opt := tcpip.TCPTimeWaitReuseOption(tc.v)
+ err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)
+ if got, want := err, tc.err; got != want {
+ t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err)
+ }
+ if tc.err != nil {
+ continue
+ }
+
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err)
+ }
+
+ if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want {
+ t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
+ }
+ }
+}
+
+// generateRandomPayload generates a random byte slice of the specified length
+// causing a fatal test failure if it is unable to do so.
+func generateRandomPayload(t *testing.T, n int) []byte {
+ t.Helper()
+ buf := make([]byte, n)
+ if _, err := rand.Read(buf); err != nil {
+ t.Fatalf("rand.Read(buf) failed: %s", err)
+ }
+ return buf
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 8edbff964..0f9ed06cd 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -131,8 +131,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
defer c.Cleanup()
if cookieEnabled {
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -158,9 +159,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
checker.PayloadLen(len(data)+header.TCPMinimumSize+12),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(wndSize),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(wndSize),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
checker.TCPTimestampChecker(true, 0, tsVal+1),
),
@@ -180,7 +181,8 @@ func TestTimeStampEnabledAccept(t *testing.T) {
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
+ // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that.
+ {false, 5, 0x4000},
}
for _, tc := range testCases {
timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
@@ -192,8 +194,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
defer c.Cleanup()
if cookieEnabled {
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ var opt tcpip.TCPSynRcvdCountThresholdOption
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -217,9 +220,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(wndSize),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(wndSize),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
checker.TCPTimestampChecker(false, 0, 0),
),
@@ -235,7 +238,9 @@ func TestTimeStampDisabledAccept(t *testing.T) {
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
+ // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of
+ // that.
+ {false, 5, 0x4000},
}
for _, tc := range testCases {
timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 06fde2a79..faf51ef95 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -53,11 +53,11 @@ const (
TestPort = 4096
// StackV6Addr is the IPv6 address assigned to the stack.
- StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
// TestV6Addr is the source address for packets sent to the stack via
// the link layer endpoint.
- TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
// StackV4MappedAddr is StackAddr as a mapped v6 address.
StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
@@ -73,6 +73,18 @@ const (
testInitialSequenceNumber = 789
)
+// StackAddrWithPrefix is StackAddr with its associated prefix length.
+var StackAddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackAddr,
+ PrefixLen: 24,
+}
+
+// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length.
+var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackV6Addr,
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+}
+
// Headers is used to represent the TCP header fields when building a
// new packet.
type Headers struct {
@@ -133,30 +145,39 @@ type Context struct {
// WindowScale is the expected window scale in SYN packets sent by
// the stack.
WindowScale uint8
+
+ // RcvdWindowScale is the actual window scale sent by the stack in
+ // SYN/SYN-ACK.
+ RcvdWindowScale uint8
}
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
})
+ const sendBufferSize = 1 << 20 // 1 MiB
+ const recvBufferSize = 1 << 20 // 1 MiB
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err)
}
// Increase minimum RTO in tests to avoid test flakes due to early
// retransmit in case the test executors are overloaded and cause timers
// to fire earlier than expected.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil {
- t.Fatalf("failed to set stack-wide minRTO: %s", err)
+ minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
}
// Some of the congestion control tests send up to 640 packets, we so
@@ -179,12 +200,20 @@ func New(t *testing.T, mtu uint32) *Context {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: StackAddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
}
- if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: StackV6AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -202,7 +231,7 @@ func New(t *testing.T, mtu uint32) *Context {
t: t,
s: s,
linkEP: ep,
- WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
+ WindowScale: uint8(tcp.FindWndScale(recvBufferSize)),
}
}
@@ -236,18 +265,17 @@ func (c *Context) CheckNoPacket(errMsg string) {
c.CheckNoPacketTimeout(errMsg, 1*time.Second)
}
-// GetPacket reads a packet from the link layer endpoint and verifies
+// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies
// that it is an IPv4 packet with the expected source and destination
-// addresses. It will fail with an error if no packet is received for
-// 2 seconds.
-func (c *Context) GetPacket() []byte {
+// addresses. If no packet is received in the specified timeout it will return
+// nil.
+func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte {
c.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
- c.t.Fatalf("Packet wasn't written out")
return nil
}
@@ -255,8 +283,16 @@ func (c *Context) GetPacket() []byte {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ // Just check that the stack set the transport protocol number for outbound
+ // TCP messages.
+ // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part
+ // of the headerinfo.
+ if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
+ c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
@@ -266,6 +302,21 @@ func (c *Context) GetPacket() []byte {
return b
}
+// GetPacket reads a packet from the link layer endpoint and verifies
+// that it is an IPv4 packet with the expected source and destination
+// addresses.
+func (c *Context) GetPacket() []byte {
+ c.t.Helper()
+
+ p := c.GetPacketWithTimeout(5 * time.Second)
+ if p == nil {
+ c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
+
+ return p
+}
+
// GetPacketNonBlocking reads a packet from the link layer endpoint
// and verifies that it is an IPv4 packet with the expected source
// and destination address. If no packet is available it will return
@@ -282,15 +333,23 @@ func (c *Context) GetPacketNonBlocking() []byte {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ // Just check that the stack set the transport protocol number for outbound
+ // TCP messages.
+ // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part
+ // of the headerinfo.
+ if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
+ c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
return b
}
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
-func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
+func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, p1, p2 []byte, maxTotalSize int) {
// Allocate a buffer data and headers.
buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2))
if len(buf) > maxTotalSize {
@@ -316,9 +375,10 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
})
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// BuildSegment builds a TCP segment based on the given Headers and payload.
@@ -372,26 +432,29 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: s,
})
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// SendPacket builds and sends a TCP segment(with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: c.BuildSegment(payload, h),
})
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
// provided source and destination IPv4 addresses.
func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
})
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// SendAck sends an ACK packet.
@@ -441,8 +504,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op
checker.PayloadLen(size+header.TCPMinimumSize+optlen),
checker.TCP(
checker.DstPort(TestPort),
- checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+ checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -468,8 +531,8 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int
checker.PayloadLen(size+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(TestPort),
- checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+ checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -512,9 +575,8 @@ func (c *Context) GetV6Packet() []byte {
if p.Proto != ipv6.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
}
- b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
- copy(b, p.Pkt.Header.View())
- copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
return b
@@ -564,9 +626,10 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
})
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt)
}
// CreateConnected creates a connected TCP endpoint.
@@ -607,6 +670,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
}
tcpHdr := header.TCP(header.IPv4(b).Payload())
+ synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */)
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
c.SendPacket(nil, &Headers{
@@ -624,15 +688,15 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
checker.TCP(
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
),
)
// Wait for connection to be established.
select {
case <-notifyCh:
- if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
+ if err := c.EP.LastError(); err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -642,6 +706,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
}
+ c.RcvdWindowScale = uint8(synOpts.WS)
c.Port = tcpHdr.SourcePort()
}
@@ -713,17 +778,18 @@ func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
}
-// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
-// tsVal.
-func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
+// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches
+// the provided tsVal as well as returns the original packet.
+func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte {
+ r.C.t.Helper()
// Read ACK and verify that tsEcr of ACK packet is [1,2,3,4]
ackPacket := r.C.GetPacket()
checker.IPv4(r.C.t, ackPacket,
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
+ checker.TCPSeqNum(uint32(r.AckNum)),
+ checker.TCPAckNum(uint32(r.NextSeqNum)),
checker.TCPTimestampChecker(true, 0, tsVal),
),
)
@@ -731,19 +797,28 @@ func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
opts := tcpSeg.ParsedOptions()
r.RecentTS = opts.TSVal
+ return ackPacket
+}
+
+// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
+// tsVal.
+func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
+ r.C.t.Helper()
+ _ = r.VerifyAndReturnACKWithTS(tsVal)
}
// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK
// matches the provided rcvWnd.
func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) {
+ r.C.t.Helper()
ackPacket := r.C.GetPacket()
checker.IPv4(r.C.t, ackPacket,
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
- checker.Window(rcvWnd),
+ checker.TCPSeqNum(uint32(r.AckNum)),
+ checker.TCPAckNum(uint32(r.NextSeqNum)),
+ checker.TCPWindow(rcvWnd),
),
)
}
@@ -762,8 +837,8 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
+ checker.TCPSeqNum(uint32(r.AckNum)),
+ checker.TCPAckNum(uint32(r.NextSeqNum)),
checker.TCPSACKBlockChecker(sackBlocks),
),
)
@@ -855,8 +930,8 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
tcpCheckers := []checker.TransportChecker{
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS) + 1),
- checker.AckNum(uint32(iss) + 1),
+ checker.TCPSeqNum(uint32(c.IRS) + 1),
+ checker.TCPAckNum(uint32(iss) + 1),
}
// Verify that tsEcr of ACK packet is wantOptions.TSVal if the
@@ -876,8 +951,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Wait for connection to be established.
select {
case <-notifyCh:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.LastError(); err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -892,7 +966,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Mark in context that timestamp option is enabled for this endpoint.
c.TimeStampEnabled = true
-
+ c.RcvdWindowScale = uint8(synOptions.WS)
return &RawEndpoint{
C: c,
SrcPort: tcpSeg.DestinationPort(),
@@ -943,12 +1017,12 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- c.EP, _, err = ep.Accept()
+ c.EP, _, err = ep.Accept(nil)
if err != nil {
c.t.Fatalf("Accept failed: %v", err)
}
@@ -985,6 +1059,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP
// value of the window scaling option to be sent in the SYN. If synOptions.WS >
// 0 then we send the WindowScale option.
func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+ c.t.Helper()
opts := make([]byte, header.TCPOptionsMaximumSize)
offset := 0
offset += header.EncodeMSSOption(uint32(maxPayload), opts)
@@ -1023,13 +1098,14 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
// are present.
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
+ rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */)
c.IRS = seqnum.Value(tcp.SequenceNumber())
tcpCheckers := []checker.TransportChecker{
checker.SrcPort(StackPort),
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(iss) + 1),
+ checker.TCPAckNum(uint32(iss) + 1),
checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
}
@@ -1072,6 +1148,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
// Send ACK.
c.SendPacket(nil, ackHeaders)
+ c.RcvdWindowScale = uint8(rcvdSynOptions.WS)
c.Port = StackPort
return &RawEndpoint{
@@ -1091,7 +1168,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
// for the Stack in the context.
func (c *Context) SACKEnabled() bool {
- var v tcp.SACKEnabled
+ var v tcpip.TCPSACKEnabled
if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
return false
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 12bc1b5b5..558b06df0 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
return st
}
+// State returns the current state of the TCB.
+func (t *TCB) State() Result {
+ return t.state
+}
+
// IsAlive returns true as long as the connection is established(Alive)
// or connecting state.
func (t *TCB) IsAlive() bool {
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index b5d2d0ba6..c78549424 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index cae29fbff..d57ed5d79 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -139,7 +139,7 @@ type endpoint struct {
// multicastMemberships that need to be remvoed when the endpoint is
// closed. Protected by the mu mutex.
- multicastMemberships []multicastMembership
+ multicastMemberships map[multicastMembership]struct{}
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -154,6 +154,9 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
}
// +stateify savable
@@ -182,12 +185,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
// TTL=1.
//
// Linux defaults to TTL=1.
- multicastTTL: 1,
- multicastLoop: true,
- rcvBufSizeMax: 32 * 1024,
- sndBufSizeMax: 32 * 1024,
- state: StateInitial,
- uniqueID: s.UniqueID(),
+ multicastTTL: 1,
+ multicastLoop: true,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
+ multicastMemberships: make(map[multicastMembership]struct{}),
+ state: StateInitial,
+ uniqueID: s.UniqueID(),
}
// Override with stack defaults.
@@ -209,7 +213,7 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
-func (e *endpoint) takeLastError() *tcpip.Error {
+func (e *endpoint) LastError() *tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
@@ -237,10 +241,10 @@ func (e *endpoint) Close() {
e.boundPortFlags = ports.Flags{}
}
- for _, mem := range e.multicastMemberships {
+ for mem := range e.multicastMemberships {
e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
}
- e.multicastMemberships = nil
+ e.multicastMemberships = make(map[multicastMembership]struct{})
// Close the receive list and drain it.
e.rcvMu.Lock()
@@ -268,7 +272,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {}
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if err := e.takeLastError(); err != nil {
+ if err := e.LastError(); err != nil {
return buffer.View{}, tcpip.ControlMessages{}, err
}
@@ -411,7 +415,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
- if err := e.takeLastError(); err != nil {
+ if err := e.LastError(); err != nil {
return 0, nil, err
}
@@ -483,10 +487,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID = e.BindNICID
}
- if to.Addr == header.IPv4Broadcast && !e.broadcast {
- return 0, nil, tcpip.ErrBroadcastDisabled
- }
-
dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
@@ -503,6 +503,10 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
resolve = route.Resolve
}
+ if !e.broadcast && route.IsOutboundBroadcast() {
+ return 0, nil, tcpip.ErrBroadcastDisabled
+ }
+
if route.IsResolutionRequired() {
if ch, err := resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
@@ -612,6 +616,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the value is not disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
e.multicastTTL = uint8(v)
@@ -676,9 +687,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
- case tcpip.MulticastInterfaceOption:
+ case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
defer e.mu.Unlock()
@@ -714,7 +725,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.multicastNICID = nic
e.multicastAddr = addr
- case tcpip.AddMembershipOption:
+ case *tcpip.AddMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
return tcpip.ErrInvalidOptionValue
}
@@ -745,19 +756,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- for _, mem := range e.multicastMemberships {
- if mem == memToInsert {
- return tcpip.ErrPortInUse
- }
+ if _, ok := e.multicastMemberships[memToInsert]; ok {
+ return tcpip.ErrPortInUse
}
if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
return err
}
- e.multicastMemberships = append(e.multicastMemberships, memToInsert)
+ e.multicastMemberships[memToInsert] = struct{}{}
- case tcpip.RemoveMembershipOption:
+ case *tcpip.RemoveMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
return tcpip.ErrInvalidOptionValue
}
@@ -779,18 +788,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
- memToRemoveIndex := -1
e.mu.Lock()
defer e.mu.Unlock()
- for i, mem := range e.multicastMemberships {
- if mem == memToRemove {
- memToRemoveIndex = i
- break
- }
- }
- if memToRemoveIndex == -1 {
+ if _, ok := e.multicastMemberships[memToRemove]; !ok {
return tcpip.ErrBadLocalAddress
}
@@ -798,17 +800,24 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return err
}
- e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
- e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+ delete(e.multicastMemberships, memToRemove)
- case tcpip.BindToDeviceOption:
- id := tcpip.NICID(v)
+ case *tcpip.BindToDeviceOption:
+ id := tcpip.NICID(*v)
if id != 0 && !e.stack.HasNIC(id) {
return tcpip.ErrUnknownDevice
}
e.mu.Lock()
e.bindToDevice = id
e.mu.Unlock()
+
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ e.linger = *v
+ e.mu.Unlock()
}
return nil
}
@@ -906,6 +915,10 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.MTUDiscoverOption:
+ // The only supported setting is path MTU discovery disabled.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
v := int(e.multicastTTL)
@@ -946,10 +959,8 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
switch o := opt.(type) {
- case tcpip.ErrorOption:
- return e.takeLastError()
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
*o = tcpip.MulticastInterfaceOption{
@@ -963,6 +974,11 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = tcpip.BindToDeviceOption(e.bindToDevice)
e.mu.RUnlock()
+ case *tcpip.LingerOption:
+ e.mu.RLock()
+ *o = e.linger
+ e.mu.RUnlock()
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -972,13 +988,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error {
- // Allocate a buffer for the UDP header.
- hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
+ Data: data,
+ })
+ pkt.Owner = owner
- // Initialize the header.
- udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ // Initialize the UDP header.
+ udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ pkt.TransportProtocolNumber = ProtocolNumber
- length := uint16(hdr.UsedLength() + data.Size())
+ length := uint16(pkt.Size())
udp.Encode(&header.UDPFields{
SrcPort: localPort,
DstPort: remotePort,
@@ -1005,12 +1025,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
Protocol: ProtocolNumber,
TTL: ttl,
TOS: tos,
- }, &stack.PacketBuffer{
- Header: hdr,
- Data: data,
- TransportHeader: buffer.View(udp),
- Owner: owner,
- }); err != nil {
+ }, pkt); err != nil {
r.Stats().UDP.PacketSendErrors.Increment()
return err
}
@@ -1208,13 +1223,13 @@ func (*endpoint) Listen(int) *tcpip.Error {
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
if err != nil {
return id, e.bindToDevice, err
}
@@ -1354,11 +1369,27 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
return result
}
+// verifyChecksum verifies the checksum unless RX checksum offload is enabled.
+// On IPv4, UDP checksum is optional, and a zero value means the transmitter
+// omitted the checksum generation (RFC768).
+// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) bool {
+ if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
+ (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ return hdr.CalculateChecksum(xsum) == 0xffff
+ }
+ return true
+}
+
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Get the header then trim it from the view.
- hdr := header.UDP(pkt.TransportHeader)
+ hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
@@ -1366,28 +1397,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
return
}
- // Verify checksum unless RX checksum offload is enabled.
- // On IPv4, UDP checksum is optional, and a zero value means
- // the transmitter omitted the checksum generation (RFC768).
- // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
- if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
- (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
- xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
- for _, v := range pkt.Data.Views() {
- xsum = header.Checksum(v, xsum)
- }
- if hdr.CalculateChecksum(xsum) != 0xffff {
- // Checksum Error.
- e.stack.Stats().UDP.ChecksumErrors.Increment()
- e.stats.ReceiveErrors.ChecksumErrors.Increment()
- return
- }
+ if !verifyChecksum(r, hdr, pkt) {
+ // Checksum Error.
+ e.stack.Stats().UDP.ChecksumErrors.Increment()
+ e.stats.ReceiveErrors.ChecksumErrors.Increment()
+ return
}
- e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
e.stats.PacketsReceived.Increment()
+ e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
if !e.rcvReady || e.rcvClosed {
e.rcvMu.Unlock()
@@ -1420,15 +1440,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Save any useful information from the network header to the packet.
switch r.NetProto {
case header.IPv4ProtocolNumber:
- packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
- packet.packetInfo.LocalAddr = r.LocalAddress
- packet.packetInfo.DestinationAddr = r.RemoteAddress
- packet.packetInfo.NIC = r.NICID()
+ packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS()
case header.IPv6ProtocolNumber:
- packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
+ packet.tos, _ = header.IPv6(pkt.NetworkHeader().View()).TOS()
}
- packet.timestamp = e.stack.NowNanoseconds()
+ // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast
+ // address. packetInfo.LocalAddr should hold a unicast address that can be
+ // used to respond to the incoming packet.
+ packet.packetInfo.LocalAddr = r.LocalAddress
+ packet.packetInfo.DestinationAddr = r.LocalAddress
+ packet.packetInfo.NIC = r.NICID()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 851e6b635..858c99a45 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -92,7 +92,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
- for _, m := range e.multicastMemberships {
+ for m := range e.multicastMemberships {
if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 0e7464e3a..da5b1deb2 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -12,18 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package udp contains the implementation of the UDP transport protocol. To use
-// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing udp.NewProtocol() as one of the
-// transport protocols when calling stack.New(). Then endpoints can be created
-// by passing udp.ProtocolNumber as the transport protocol number when calling
-// Stack.NewEndpoint().
+// Package udp contains the implementation of the UDP transport protocol.
package udp
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/waiter"
@@ -49,6 +45,7 @@ const (
)
type protocol struct {
+ stack *stack.Stack
}
// Number returns the udp protocol number.
@@ -57,14 +54,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new udp endpoint.
-func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newEndpoint(stack, netProto, waiterQueue), nil
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw UDP endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue)
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue)
}
// MinimumPacketSize returns the minimum valid udp packet size.
@@ -79,131 +76,30 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
return h.SourcePort(), h.DestinationPort(), nil
}
-// HandleUnknownDestinationPacket handles packets targeted at this protocol but
-// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
- hdr := header.UDP(pkt.TransportHeader)
+// HandleUnknownDestinationPacket handles packets that are targeted at this
+// protocol but don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
- // Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
- return true
- }
- // TODO(b/129426613): only send an ICMP message if UDP checksum is valid.
-
- // Only send ICMP error if the address is not a multicast/broadcast
- // v4/v6 address or the source is not the unspecified address.
- //
- // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any {
- return true
+ return stack.UnknownDestinationPacketMalformed
}
- // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
- // Unreachable messages with code:
- //
- // 2 (Protocol Unreachable), when the designated transport protocol
- // is not supported; or
- //
- // 3 (Port Unreachable), when the designated transport protocol
- // (e.g., UDP) is unable to demultiplex the datagram but has no
- // protocol mechanism to inform the sender.
- switch len(id.LocalAddress) {
- case header.IPv4AddressSize:
- if !r.Stack().AllowICMPMessage() {
- r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment()
- return true
- }
- // As per RFC 1812 Section 4.3.2.3
- //
- // ICMP datagram SHOULD contain as much of the original
- // datagram as possible without the length of the ICMP
- // datagram exceeding 576 bytes
- //
- // NOTE: The above RFC referenced is different from the original
- // recommendation in RFC 1122 where it mentioned that at least 8
- // bytes of the payload must be included. Today linux and other
- // systems implement the] RFC1812 definition and not the original
- // RFC 1122 requirement.
- mtu := int(r.MTU())
- if mtu > header.IPv4MinimumProcessableDatagramSize {
- mtu = header.IPv4MinimumProcessableDatagramSize
- }
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
- available := int(mtu) - headerLen
- payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size()
- if payloadLen > available {
- payloadLen = available
- }
-
- // The buffers used by pkt may be used elsewhere in the system.
- // For example, a raw or packet socket may use what UDP
- // considers an unreachable destination. Thus we deep copy pkt
- // to prevent multiple ownership and SR errors.
- newHeader := append(buffer.View(nil), pkt.NetworkHeader...)
- newHeader = append(newHeader, pkt.TransportHeader...)
- payload := newHeader.ToVectorisedView()
- payload.AppendView(pkt.Data.ToView())
- payload.CapLength(payloadLen)
-
- hdr := buffer.NewPrependable(headerLen)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4DstUnreachable)
- pkt.SetCode(header.ICMPv4PortUnreachable)
- pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
- Header: hdr,
- TransportHeader: buffer.View(pkt),
- Data: payload,
- })
-
- case header.IPv6AddressSize:
- if !r.Stack().AllowICMPMessage() {
- r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment()
- return true
- }
-
- // As per RFC 4443 section 2.4
- //
- // (c) Every ICMPv6 error message (type < 128) MUST include
- // as much of the IPv6 offending (invoking) packet (the
- // packet that caused the error) as possible without making
- // the error message packet exceed the minimum IPv6 MTU
- // [IPv6].
- mtu := int(r.MTU())
- if mtu > header.IPv6MinimumMTU {
- mtu = header.IPv6MinimumMTU
- }
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
- available := int(mtu) - headerLen
- payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size()
- if payloadLen > available {
- payloadLen = available
- }
- payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader})
- payload.Append(pkt.Data)
- payload.CapLength(payloadLen)
-
- hdr := buffer.NewPrependable(headerLen)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize))
- pkt.SetType(header.ICMPv6DstUnreachable)
- pkt.SetCode(header.ICMPv6PortUnreachable)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
- Header: hdr,
- TransportHeader: buffer.View(pkt),
- Data: payload,
- })
+ if !verifyChecksum(r, hdr, pkt) {
+ r.Stack().Stats().UDP.ChecksumErrors.Increment()
+ return stack.UnknownDestinationPacketMalformed
}
- return true
+
+ return stack.UnknownDestinationPacketUnhandled
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements stack.TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
@@ -215,17 +111,10 @@ func (*protocol) Wait() {}
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
- h, ok := pkt.Data.PullUp(header.UDPMinimumSize)
- if !ok {
- // Packet is too small
- return false
- }
- pkt.TransportHeader = h
- pkt.Data.TrimFront(header.UDPMinimumSize)
- return true
+ return parse.UDP(pkt)
}
// NewProtocol returns a UDP transport protocol.
-func NewProtocol() stack.TransportProtocol {
- return &protocol{}
+func NewProtocol(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index db59eb5a0..cddedb686 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -83,16 +83,18 @@ type header4Tuple struct {
type testFlow int
const (
- unicastV4 testFlow = iota // V4 unicast on a V4 socket
- unicastV4in6 // V4-mapped unicast on a V6-dual socket
- unicastV6 // V6 unicast on a V6 socket
- unicastV6Only // V6 unicast on a V6-only socket
- multicastV4 // V4 multicast on a V4 socket
- multicastV4in6 // V4-mapped multicast on a V6-dual socket
- multicastV6 // V6 multicast on a V6 socket
- multicastV6Only // V6 multicast on a V6-only socket
- broadcast // V4 broadcast on a V4 socket
- broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ reverseMulticast4 // V4 multicast src. Must fail.
+ reverseMulticast6 // V6 multicast src. Must fail.
)
func (flow testFlow) String() string {
@@ -117,6 +119,10 @@ func (flow testFlow) String() string {
return "broadcast"
case broadcastIn6:
return "broadcastIn6"
+ case reverseMulticast4:
+ return "reverseMulticast4"
+ case reverseMulticast6:
+ return "reverseMulticast6"
default:
return "unknown"
}
@@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
h.dstAddr.Addr = multicastV6Addr
}
}
+ if flow.isReverseMulticast() {
+ h.srcAddr.Addr = flow.getMcastAddr()
+ }
return h
}
@@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
// endpoint for this flow.
func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
switch flow {
- case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6:
return ipv6.ProtocolNumber
- case unicastV4, multicastV4, broadcast:
+ case unicastV4, multicastV4, broadcast, reverseMulticast4:
return ipv4.ProtocolNumber
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool {
switch flow {
case unicastV6Only, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool {
switch flow {
case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool {
switch flow {
case broadcast, broadcastIn6:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool {
switch flow {
case unicastV4in6, multicastV4in6, broadcastIn6:
return true
- case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
}
}
+func (flow testFlow) isReverseMulticast() bool {
+ switch flow {
+ case reverseMulticast4, reverseMulticast6:
+ return true
+ default:
+ return false
+ }
+}
+
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
@@ -276,8 +294,8 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
return newDualTestContextWithOptions(t, mtu, stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
}
@@ -370,8 +388,12 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
}
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want {
+ c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want)
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
h := flow.header4Tuple(outgoing)
checkers = append(
@@ -385,21 +407,38 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw
}
// injectPacket creates a packet of the given flow and with the given payload,
-// and injects it into the link endpoint.
-func (c *testContext) injectPacket(flow testFlow, payload []byte) {
+// and injects it into the link endpoint. If badChecksum is true, the packet has
+// a bad checksum in the UDP header.
+func (c *testContext) injectPacket(flow testFlow, payload []byte, badChecksum bool) {
c.t.Helper()
h := flow.header4Tuple(incoming)
if flow.isV4() {
buf := c.buildV4Packet(payload, &h)
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ if badChecksum {
+ // Invalidate the UDP header checksum field, taking care to avoid
+ // overflow to zero, which would disable checksum validation.
+ for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
+ u.SetChecksum(u.Checksum() + 1)
+ if u.Checksum() != 0 {
+ break
+ }
+ }
+ }
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
} else {
buf := c.buildV6Packet(payload, &h)
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ if badChecksum {
+ // Invalidate the UDP header checksum field (Unlike IPv4, zero is
+ // a valid checksum value for IPv6 so no need to avoid it).
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(u.Checksum() + 1)
+ }
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
}
}
@@ -493,8 +532,8 @@ func newMinPayload(minSize int) []byte {
func TestBindToDeviceOption(t *testing.T) {
s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}})
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
@@ -504,7 +543,7 @@ func TestBindToDeviceOption(t *testing.T) {
opts := stack.NICOptions{Name: "my_device"}
if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
- t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
+ t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err)
}
// nicIDPtr is used instead of taking the address of NICID literals, which is
@@ -528,16 +567,15 @@ func TestBindToDeviceOption(t *testing.T) {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
+ if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
}
}
bindToDevice := tcpip.BindToDeviceOption(88888)
if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt got %v, want %v", err, nil)
- }
- if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %d, want %d", got, want)
+ t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err)
+ } else if bindToDevice != testAction.getBindToDevice {
+ t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice)
}
})
}
@@ -551,7 +589,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
c.t.Helper()
payload := newPayload()
- c.injectPacket(flow, payload)
+ c.injectPacket(flow, payload, false)
// Try to receive the data.
we, ch := waiter.NewChannelEntry(nil)
@@ -593,12 +631,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
// Check the peer address.
h := flow.header4Tuple(incoming)
if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr)
+ c.t.Fatalf("got address = %s, want = %s", addr.Addr, h.srcAddr.Addr)
}
// Check the payload.
if !bytes.Equal(payload, v) {
- c.t.Fatalf("bad payload: got %x, want %x", v, payload)
+ c.t.Fatalf("got payload = %x, want = %x", v, payload)
}
// Run any checkers against the ControlMessages.
@@ -659,7 +697,7 @@ func TestBindReservedPort(t *testing.T) {
}
defer ep.Close()
if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
- t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
+ t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want)
}
}
@@ -672,7 +710,7 @@ func TestBindReservedPort(t *testing.T) {
// We can't bind ipv4-any on the port reserved by the connected endpoint
// above, since the endpoint is dual-stack.
if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want {
- t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
+ t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want)
}
// We can bind an ipv4 address on this port, though.
if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
@@ -769,8 +807,8 @@ func TestV4ReadSelfSource(t *testing.T) {
} {
t.Run(tt.name, func(t *testing.T) {
c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
HandleLocal: tt.handleLocal,
})
defer c.cleanup()
@@ -786,16 +824,16 @@ func TestV4ReadSelfSource(t *testing.T) {
h.srcAddr = h.dstAddr
buf := c.buildV4Packet(payload, &h)
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
}
if _, _, err := c.ep.Read(nil); err != tt.wantErr {
- t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr)
+ t.Errorf("got c.ep.Read(nil) = %s, want = %s", err, tt.wantErr)
}
})
}
@@ -836,8 +874,8 @@ func TestReadOnBoundToMulticast(t *testing.T) {
// Join multicast group.
ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatal("SetSockOpt failed:", err)
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
}
// Check that we receive multicast packets but not unicast or broadcast
@@ -872,6 +910,24 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
}
}
+// TestReadFromMulticast checks that an endpoint will NOT receive a packet
+// that was sent with multicast SOURCE address.
+func TestReadFromMulticast(t *testing.T) {
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ testFailingRead(c, flow, false /* expectReadError */)
+ })
+ }
+}
+
// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
// and receive broadcast and unicast data.
func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
@@ -1237,6 +1293,105 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
}
}
+func TestReadIPPacketInfo(t *testing.T) {
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ expectedLocalAddr tcpip.Address
+ expectedDestAddr tcpip.Address
+ }{
+ {
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ expectedLocalAddr: stackAddr,
+ expectedDestAddr: stackAddr,
+ },
+ {
+ name: "IPv4 multicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: multicastV4,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: multicastAddr,
+ expectedDestAddr: multicastAddr,
+ },
+ {
+ name: "IPv4 broadcast",
+ proto: header.IPv4ProtocolNumber,
+ flow: broadcast,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: broadcastAddr,
+ expectedDestAddr: broadcastAddr,
+ },
+ {
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ expectedLocalAddr: stackV6Addr,
+ expectedDestAddr: stackV6Addr,
+ },
+ {
+ name: "IPv6 multicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: multicastV6,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: multicastV6Addr,
+ expectedDestAddr: multicastV6Addr,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(test.proto)
+
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ if test.flow.isMulticast() {
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
+ }
+ }
+
+ if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil {
+ t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err)
+ }
+
+ testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: 1,
+ LocalAddr: test.expectedLocalAddr,
+ DestinationAddr: test.expectedDestAddr,
+ }))
+
+ if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
+ t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
+ }
+ })
+ }
+}
+
func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1275,6 +1430,30 @@ func TestNoChecksum(t *testing.T) {
}
}
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct{}
+
+func (*testInterface) ID() tcpip.NICID {
+ return 0
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
+func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
+ return nil
+}
+
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1292,19 +1471,19 @@ func TestTTL(t *testing.T) {
if flow.isMulticast() {
wantTTL = multicastTTL
} else {
- var p stack.NetworkProtocol
+ var p stack.NetworkProtocolFactory
+ var n tcpip.NetworkProtocolNumber
if flow.isV4() {
- p = ipv4.NewProtocol()
+ p = ipv4.NewProtocol
+ n = ipv4.ProtocolNumber
} else {
- p = ipv6.NewProtocol()
- }
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- }))
- if err != nil {
- t.Fatal(err)
+ p = ipv6.NewProtocol
+ n = ipv6.ProtocolNumber
}
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{p},
+ })
+ ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil, nil, nil)
wantTTL = ep.DefaultTTL()
ep.Close()
}
@@ -1328,21 +1507,6 @@ func TestSetTTL(t *testing.T) {
c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
}
- var p stack.NetworkProtocol
- if flow.isV4() {
- p = ipv4.NewProtocol()
- } else {
- p = ipv6.NewProtocol()
- }
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- }))
- if err != nil {
- t.Fatal(err)
- }
- ep.Close()
-
testWrite(c, flow, checker.TTL(wantTTL))
})
}
@@ -1365,7 +1529,7 @@ func TestSetTOS(t *testing.T) {
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
+ c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
}
if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
@@ -1526,19 +1690,17 @@ func TestMulticastInterfaceOption(t *testing.T) {
}
}
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %s", err)
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
}
// Verify multicast interface addr and NIC were set correctly.
// Note that NIC must be 1 since this is our outgoing interface.
- ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
var ifoptGot tcpip.MulticastInterfaceOption
if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %s", err)
- }
- if ifoptGot != ifoptWant {
- c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
+ c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err)
+ } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant {
+ c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant)
}
})
}
@@ -1562,21 +1724,33 @@ func TestV4UnknownDestination(t *testing.T) {
// so that the final generated IPv4 packet is larger than
// header.IPv4MinimumProcessableDatagramSize.
largePayload bool
+ // badChecksum if true, will set an invalid checksum in the
+ // header.
+ badChecksum bool
}{
- {unicastV4, true, false},
- {unicastV4, true, true},
- {multicastV4, false, false},
- {multicastV4, false, true},
- {broadcast, false, false},
- {broadcast, false, true},
- }
+ {unicastV4, true, false, false},
+ {unicastV4, true, true, false},
+ {unicastV4, false, false, true},
+ {unicastV4, false, true, true},
+ {multicastV4, false, false, false},
+ {multicastV4, false, true, false},
+ {broadcast, false, false, false},
+ {broadcast, false, true, false},
+ }
+ checksumErrors := uint64(0)
for _, tc := range testCases {
- t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
payload := newPayload()
if tc.largePayload {
payload = newMinPayload(576)
}
- c.injectPacket(tc.flow, payload)
+ c.injectPacket(tc.flow, payload, tc.badChecksum)
+ if tc.badChecksum {
+ checksumErrors++
+ if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
+ t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ }
if !tc.icmpRequired {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
@@ -1595,9 +1769,8 @@ func TestV4UnknownDestination(t *testing.T) {
return
}
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
}
@@ -1607,16 +1780,26 @@ func TestV4UnknownDestination(t *testing.T) {
checker.ICMPv4Type(header.ICMPv4DstUnreachable),
checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ // We need to compare the included data part of the UDP packet that is in
+ // the ICMP packet with the matching original data.
icmpPkt := header.ICMPv4(hdr.Payload())
payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
wantLen := len(payload)
if tc.largePayload {
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ // To work out the data size we need to simulate what the sender would
+ // have done. The wanted size is the total available minus the sum of
+ // the headers in the UDP AND ICMP packets, given that we know the test
+ // had only a minimal IP header but the ICMP sender will have allowed
+ // for a maximally sized packet header.
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
+
}
- // In case of large payloads the IP packet may be truncated. Update
+ // In the case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+ // Add back the two headers within the payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
origDgram := header.UDP(payloadIPHeader.Payload())
if got, want := len(origDgram.Payload()), wantLen; got != want {
@@ -1642,19 +1825,31 @@ func TestV6UnknownDestination(t *testing.T) {
// largePayload if true will result in a payload large enough to
// create an IPv6 packet > header.IPv6MinimumMTU bytes.
largePayload bool
+ // badChecksum if true, will set an invalid checksum in the
+ // header.
+ badChecksum bool
}{
- {unicastV6, true, false},
- {unicastV6, true, true},
- {multicastV6, false, false},
- {multicastV6, false, true},
- }
+ {unicastV6, true, false, false},
+ {unicastV6, true, true, false},
+ {unicastV6, false, false, true},
+ {unicastV6, false, true, true},
+ {multicastV6, false, false, false},
+ {multicastV6, false, true, false},
+ }
+ checksumErrors := uint64(0)
for _, tc := range testCases {
- t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
payload := newPayload()
if tc.largePayload {
payload = newMinPayload(1280)
}
- c.injectPacket(tc.flow, payload)
+ c.injectPacket(tc.flow, payload, tc.badChecksum)
+ if tc.badChecksum {
+ checksumErrors++
+ if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
+ t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ }
if !tc.icmpRequired {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
@@ -1673,9 +1868,8 @@ func TestV6UnknownDestination(t *testing.T) {
return
}
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
}
@@ -1721,12 +1915,14 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) {
payload := newPayload()
h := unicastV6.header4Tuple(incoming)
buf := c.buildV6Packet(payload, &h)
- // Invalidate the packet length field in the UDP header by adding one.
+
+ // Invalidate the UDP header length field.
u := header.UDP(buf[header.IPv6MinimumSize:])
u.SetLength(u.Length() + 1)
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
const want = 1
if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
@@ -1779,74 +1975,38 @@ func TestShortHeader(t *testing.T) {
copy(buf[header.IPv6MinimumSize:], udpHdr)
// Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want {
t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want)
}
}
-// TestIncrementChecksumErrorsV4 verifies if a checksum error is detected,
+// TestBadChecksumErrors verifies if a checksum error is detected,
// global and endpoint stats are incremented.
-func TestIncrementChecksumErrorsV4(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv4.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV4.header4Tuple(incoming)
- buf := c.buildV4Packet(payload, &h)
- // Invalidate the checksum field in the UDP header by adding one.
- u := header.UDP(buf[header.IPv4MinimumSize:])
- u.SetChecksum(u.Checksum() + 1)
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestIncrementChecksumErrorsV6 verifies if a checksum error is detected,
-// global and endpoint stats are incremented.
-func TestIncrementChecksumErrorsV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
+func TestBadChecksumErrors(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV6} {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
+ c.createEndpoint(flow.sockProto())
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
- payload := newPayload()
- h := unicastV6.header4Tuple(incoming)
- buf := c.buildV6Packet(payload, &h)
- // Invalidate the checksum field in the UDP header by adding one.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.SetChecksum(u.Checksum() + 1)
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ payload := newPayload()
+ c.injectPacket(flow, payload, true /* badChecksum */)
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
}
}
@@ -1865,11 +2025,12 @@ func TestPayloadModifiedV4(t *testing.T) {
payload := newPayload()
h := unicastV4.header4Tuple(incoming)
buf := c.buildV4Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ // Modify the payload so that the checksum value in the UDP header will be
+ // incorrect.
buf[len(buf)-1]++
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
const want = 1
if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
@@ -1895,11 +2056,12 @@ func TestPayloadModifiedV6(t *testing.T) {
payload := newPayload()
h := unicastV6.header4Tuple(incoming)
buf := c.buildV6Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ // Modify the payload so that the checksum value in the UDP header will be
+ // incorrect.
buf[len(buf)-1]++
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
const want = 1
if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
@@ -1928,9 +2090,9 @@ func TestChecksumZeroV4(t *testing.T) {
// Set the checksum field in the UDP header to zero.
u := header.UDP(buf[header.IPv4MinimumSize:])
u.SetChecksum(0)
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
const want = 0
if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
@@ -1959,9 +2121,9 @@ func TestChecksumZeroV6(t *testing.T) {
// Set the checksum field in the UDP header to zero.
u := header.UDP(buf[header.IPv6MinimumSize:])
u.SetChecksum(0)
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
const want = 1
if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
@@ -2059,3 +2221,193 @@ func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEn
c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
}
}
+
+func TestOutgoingSubnetBroadcast(t *testing.T) {
+ const nicID1 = 1
+
+ ipv4Addr := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 24,
+ }
+ ipv4Subnet := ipv4Addr.Subnet()
+ ipv4SubnetBcast := ipv4Subnet.Broadcast()
+ ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 31,
+ }
+ ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
+ ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
+ ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 32,
+ }
+ ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
+ ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
+ ipv6Addr := tcpip.AddressWithPrefix{
+ Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ PrefixLen: 64,
+ }
+ ipv6Subnet := ipv6Addr.Subnet()
+ ipv6SubnetBcast := ipv6Subnet.Broadcast()
+ remNetAddr := tcpip.AddressWithPrefix{
+ Address: "\x64\x0a\x7b\x18",
+ PrefixLen: 24,
+ }
+ remNetSubnet := remNetAddr.Subnet()
+ remNetSubnetBcast := remNetSubnet.Broadcast()
+
+ tests := []struct {
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ requiresBroadcastOpt bool
+ }{
+ {
+ name: "IPv4 Broadcast to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4SubnetBcast,
+ requiresBroadcastOpt: true,
+ },
+ {
+ name: "IPv4 Broadcast to local /31 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix31,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet31,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet31Bcast,
+ requiresBroadcastOpt: false,
+ },
+ {
+ name: "IPv4 Broadcast to local /32 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix32,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet32,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet32Bcast,
+ requiresBroadcastOpt: false,
+ },
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 'Broadcast' to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv6Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv6SubnetBcast,
+ requiresBroadcastOpt: false,
+ },
+ {
+ name: "IPv4 Broadcast to remote subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: remNetSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ // TODO(gvisor.dev/issue/3938): Once we support marking a route as
+ // broadcast, this test should require the broadcast option to be set.
+ requiresBroadcastOpt: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ }
+
+ s.SetRouteTable(test.routes)
+
+ var netProto tcpip.NetworkProtocolNumber
+ switch l := len(test.remoteAddr); l {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ netProto = header.IPv6ProtocolNumber
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err)
+ }
+ defer ep.Close()
+
+ data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ to := tcpip.FullAddress{
+ Addr: test.remoteAddr,
+ Port: 80,
+ }
+ opts := tcpip.WriteOptions{To: &to}
+ expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled
+ if !test.requiresBroadcastOpt {
+ expectedErrWithoutBcastOpt = nil
+ }
+
+ if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err)
+ }
+
+ if n, _, err := ep.Write(data, opts); err != nil {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil {
+ t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err)
+ }
+
+ if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ }
+ })
+ }
+}