summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/buffer/BUILD4
-rw-r--r--pkg/tcpip/checker/checker.go60
-rw-r--r--pkg/tcpip/faketime/BUILD24
-rw-r--r--pkg/tcpip/faketime/faketime.go (renamed from pkg/tcpip/stack/fake_time_test.go)137
-rw-r--r--pkg/tcpip/faketime/faketime_test.go95
-rw-r--r--pkg/tcpip/header/icmpv4.go50
-rw-r--r--pkg/tcpip/header/icmpv6.go35
-rw-r--r--pkg/tcpip/header/ipv4.go5
-rw-r--r--pkg/tcpip/network/arp/arp.go6
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go180
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go187
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go125
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go2
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go3
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go97
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go2
-rw-r--r--pkg/tcpip/network/testutil/BUILD7
-rw-r--r--pkg/tcpip/network/testutil/testutil.go102
-rw-r--r--pkg/tcpip/stack/BUILD3
-rw-r--r--pkg/tcpip/stack/forwarder_test.go6
-rw-r--r--pkg/tcpip/stack/ndp.go2
-rw-r--r--pkg/tcpip/stack/ndp_test.go12
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go67
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go3
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go47
-rw-r--r--pkg/tcpip/stack/nic.go34
-rw-r--r--pkg/tcpip/stack/nic_test.go5
-rw-r--r--pkg/tcpip/stack/packet_buffer.go29
-rw-r--r--pkg/tcpip/stack/registration.go45
-rw-r--r--pkg/tcpip/stack/stack.go72
-rw-r--r--pkg/tcpip/stack/stack_test.go13
-rw-r--r--pkg/tcpip/stack/transport_test.go6
-rw-r--r--pkg/tcpip/tcpip.go11
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go4
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/connect.go20
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go16
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go165
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go4
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go13
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go50
-rw-r--r--pkg/tcpip/transport/tcp/segment.go45
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go52
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go893
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go19
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go101
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go1
-rw-r--r--pkg/tcpip/transport/udp/protocol.go117
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go4
52 files changed, 1844 insertions, 1144 deletions
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index 563bc78ea..c326fab54 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -14,6 +14,8 @@ go_library(
go_test(
name = "buffer_test",
size = "small",
- srcs = ["view_test.go"],
+ srcs = [
+ "view_test.go",
+ ],
library = ":buffer",
)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index b769094dc..19627fa9b 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -339,7 +339,7 @@ func NoChecksum(noChecksum bool) TransportChecker {
udp, ok := h.(header.UDP)
if !ok {
- return
+ t.Fatalf("UDP header not found in h: %T", h)
}
if b := udp.Checksum() == 0; b != noChecksum {
@@ -348,14 +348,14 @@ func NoChecksum(noChecksum bool) TransportChecker {
}
}
-// SeqNum creates a checker that checks the sequence number.
-func SeqNum(seq uint32) TransportChecker {
+// TCPSeqNum creates a checker that checks the sequence number.
+func TCPSeqNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in h: %T", h)
}
if s := tcp.SequenceNumber(); s != seq {
@@ -364,14 +364,14 @@ func SeqNum(seq uint32) TransportChecker {
}
}
-// AckNum creates a checker that checks the ack number.
-func AckNum(seq uint32) TransportChecker {
+// TCPAckNum creates a checker that checks the ack number.
+func TCPAckNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in h: %T", h)
}
if s := tcp.AckNumber(); s != seq {
@@ -380,18 +380,52 @@ func AckNum(seq uint32) TransportChecker {
}
}
-// Window creates a checker that checks the tcp window.
-func Window(window uint16) TransportChecker {
+// TCPWindow creates a checker that checks the tcp window.
+func TCPWindow(window uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in hdr : %T", h)
}
if w := tcp.WindowSize(); w != window {
- t.Errorf("Bad window, got 0x%x, want 0x%x", w, window)
+ t.Errorf("Bad window, got %d, want %d", w, window)
+ }
+ }
+}
+
+// TCPWindowGreaterThanEq creates a checker that checks that the TCP window
+// is greater than or equal to the provided value.
+func TCPWindowGreaterThanEq(window uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ t.Fatalf("TCP header not found in h: %T", h)
+ }
+
+ if w := tcp.WindowSize(); w < window {
+ t.Errorf("Bad window, got %d, want > %d", w, window)
+ }
+ }
+}
+
+// TCPWindowLessThanEq creates a checker that checks that the tcp window
+// is less than or equal to the provided value.
+func TCPWindowLessThanEq(window uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ t.Fatalf("TCP header not found in h: %T", h)
+ }
+
+ if w := tcp.WindowSize(); w > window {
+ t.Errorf("Bad window, got %d, want < %d", w, window)
}
}
}
@@ -403,7 +437,7 @@ func TCPFlags(flags uint8) TransportChecker {
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in h: %T", h)
}
if f := tcp.Flags(); f != flags {
@@ -420,7 +454,7 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker {
tcp, ok := h.(header.TCP)
if !ok {
- return
+ t.Fatalf("TCP header not found in h: %T", h)
}
if f := tcp.Flags(); (f & mask) != (flags & mask) {
diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD
new file mode 100644
index 000000000..114d43df3
--- /dev/null
+++ b/pkg/tcpip/faketime/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "faketime",
+ srcs = ["faketime.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "@com_github_dpjacques_clockwork//:go_default_library",
+ ],
+)
+
+go_test(
+ name = "faketime_test",
+ size = "small",
+ srcs = [
+ "faketime_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip/faketime",
+ ],
+)
diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/faketime/faketime.go
index 92c8cb534..1193f1d7d 100644
--- a/pkg/tcpip/stack/fake_time_test.go
+++ b/pkg/tcpip/faketime/faketime.go
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package stack
+// Package faketime provides a fake clock that implements tcpip.Clock interface.
+package faketime
import (
"container/heap"
@@ -23,7 +24,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-type fakeClock struct {
+// ManualClock implements tcpip.Clock and only advances manually with Advance
+// method.
+type ManualClock struct {
clock clockwork.FakeClock
// mu protects the fields below.
@@ -39,34 +42,35 @@ type fakeClock struct {
waitGroups map[time.Time]*sync.WaitGroup
}
-func newFakeClock() *fakeClock {
- return &fakeClock{
+// NewManualClock creates a new ManualClock instance.
+func NewManualClock() *ManualClock {
+ return &ManualClock{
clock: clockwork.NewFakeClock(),
times: &timeHeap{},
waitGroups: make(map[time.Time]*sync.WaitGroup),
}
}
-var _ tcpip.Clock = (*fakeClock)(nil)
+var _ tcpip.Clock = (*ManualClock)(nil)
// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (fc *fakeClock) NowNanoseconds() int64 {
- return fc.clock.Now().UnixNano()
+func (mc *ManualClock) NowNanoseconds() int64 {
+ return mc.clock.Now().UnixNano()
}
// NowMonotonic implements tcpip.Clock.NowMonotonic.
-func (fc *fakeClock) NowMonotonic() int64 {
- return fc.NowNanoseconds()
+func (mc *ManualClock) NowMonotonic() int64 {
+ return mc.NowNanoseconds()
}
// AfterFunc implements tcpip.Clock.AfterFunc.
-func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
- until := fc.clock.Now().Add(d)
- wg := fc.addWait(until)
- return &fakeTimer{
- clock: fc,
+func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ until := mc.clock.Now().Add(d)
+ wg := mc.addWait(until)
+ return &manualTimer{
+ clock: mc,
until: until,
- timer: fc.clock.AfterFunc(d, func() {
+ timer: mc.clock.AfterFunc(d, func() {
defer wg.Done()
f()
}),
@@ -75,110 +79,113 @@ func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
// addWait adds an additional wait to the WaitGroup for parallel execution of
// all work scheduled for t. Returns a reference to the WaitGroup modified.
-func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup {
- fc.mu.RLock()
- wg, ok := fc.waitGroups[t]
- fc.mu.RUnlock()
+func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup {
+ mc.mu.RLock()
+ wg, ok := mc.waitGroups[t]
+ mc.mu.RUnlock()
if ok {
wg.Add(1)
return wg
}
- fc.mu.Lock()
- heap.Push(fc.times, t)
- fc.mu.Unlock()
+ mc.mu.Lock()
+ heap.Push(mc.times, t)
+ mc.mu.Unlock()
wg = &sync.WaitGroup{}
wg.Add(1)
- fc.mu.Lock()
- fc.waitGroups[t] = wg
- fc.mu.Unlock()
+ mc.mu.Lock()
+ mc.waitGroups[t] = wg
+ mc.mu.Unlock()
return wg
}
// removeWait removes a wait from the WaitGroup for parallel execution of all
// work scheduled for t.
-func (fc *fakeClock) removeWait(t time.Time) {
- fc.mu.RLock()
- defer fc.mu.RUnlock()
+func (mc *ManualClock) removeWait(t time.Time) {
+ mc.mu.RLock()
+ defer mc.mu.RUnlock()
- wg := fc.waitGroups[t]
+ wg := mc.waitGroups[t]
wg.Done()
}
-// advance executes all work that have been scheduled to execute within d from
-// the current fake time. Blocks until all work has completed execution.
-func (fc *fakeClock) advance(d time.Duration) {
+// Advance executes all work that have been scheduled to execute within d from
+// the current time. Blocks until all work has completed execution.
+func (mc *ManualClock) Advance(d time.Duration) {
// Block until all the work is done
- until := fc.clock.Now().Add(d)
+ until := mc.clock.Now().Add(d)
for {
- fc.mu.Lock()
- if fc.times.Len() == 0 {
- fc.mu.Unlock()
- return
+ mc.mu.Lock()
+ if mc.times.Len() == 0 {
+ mc.mu.Unlock()
+ break
}
- t := heap.Pop(fc.times).(time.Time)
+ t := heap.Pop(mc.times).(time.Time)
if t.After(until) {
// No work to do
- heap.Push(fc.times, t)
- fc.mu.Unlock()
- return
+ heap.Push(mc.times, t)
+ mc.mu.Unlock()
+ break
}
- fc.mu.Unlock()
+ mc.mu.Unlock()
- diff := t.Sub(fc.clock.Now())
- fc.clock.Advance(diff)
+ diff := t.Sub(mc.clock.Now())
+ mc.clock.Advance(diff)
- fc.mu.RLock()
- wg := fc.waitGroups[t]
- fc.mu.RUnlock()
+ mc.mu.RLock()
+ wg := mc.waitGroups[t]
+ mc.mu.RUnlock()
wg.Wait()
- fc.mu.Lock()
- delete(fc.waitGroups, t)
- fc.mu.Unlock()
+ mc.mu.Lock()
+ delete(mc.waitGroups, t)
+ mc.mu.Unlock()
+ }
+ if now := mc.clock.Now(); until.After(now) {
+ mc.clock.Advance(until.Sub(now))
}
}
-type fakeTimer struct {
- clock *fakeClock
+type manualTimer struct {
+ clock *ManualClock
timer clockwork.Timer
mu sync.RWMutex
until time.Time
}
-var _ tcpip.Timer = (*fakeTimer)(nil)
+var _ tcpip.Timer = (*manualTimer)(nil)
// Reset implements tcpip.Timer.Reset.
-func (ft *fakeTimer) Reset(d time.Duration) {
- if !ft.timer.Reset(d) {
+func (t *manualTimer) Reset(d time.Duration) {
+ if !t.timer.Reset(d) {
return
}
- ft.mu.Lock()
- defer ft.mu.Unlock()
+ t.mu.Lock()
+ defer t.mu.Unlock()
- ft.clock.removeWait(ft.until)
- ft.until = ft.clock.clock.Now().Add(d)
- ft.clock.addWait(ft.until)
+ t.clock.removeWait(t.until)
+ t.until = t.clock.clock.Now().Add(d)
+ t.clock.addWait(t.until)
}
// Stop implements tcpip.Timer.Stop.
-func (ft *fakeTimer) Stop() bool {
- if !ft.timer.Stop() {
+func (t *manualTimer) Stop() bool {
+ if !t.timer.Stop() {
return false
}
- ft.mu.RLock()
- defer ft.mu.RUnlock()
+ t.mu.RLock()
+ defer t.mu.RUnlock()
- ft.clock.removeWait(ft.until)
+ t.clock.removeWait(t.until)
return true
}
diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go
new file mode 100644
index 000000000..c2704df2c
--- /dev/null
+++ b/pkg/tcpip/faketime/faketime_test.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package faketime_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+)
+
+func TestManualClockAdvance(t *testing.T) {
+ const timeout = time.Millisecond
+ clock := faketime.NewManualClock()
+ start := clock.NowMonotonic()
+ clock.Advance(timeout)
+ if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want {
+ t.Errorf("got = %d, want = %d", got, want)
+ }
+}
+
+func TestManualClockAfterFunc(t *testing.T) {
+ const (
+ timeout1 = time.Millisecond // timeout for counter1
+ timeout2 = 2 * time.Millisecond // timeout for counter2
+ )
+ tests := []struct {
+ name string
+ advance time.Duration
+ wantCounter1 int
+ wantCounter2 int
+ }{
+ {
+ name: "before timeout1",
+ advance: timeout1 - 1,
+ wantCounter1: 0,
+ wantCounter2: 0,
+ },
+ {
+ name: "timeout1",
+ advance: timeout1,
+ wantCounter1: 1,
+ wantCounter2: 0,
+ },
+ {
+ name: "timeout2",
+ advance: timeout2,
+ wantCounter1: 1,
+ wantCounter2: 1,
+ },
+ {
+ name: "after timeout2",
+ advance: timeout2 + 1,
+ wantCounter1: 1,
+ wantCounter2: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ counter1 := 0
+ counter2 := 0
+ clock.AfterFunc(timeout1, func() {
+ counter1++
+ })
+ clock.AfterFunc(timeout2, func() {
+ counter2++
+ })
+ start := clock.NowMonotonic()
+ clock.Advance(test.advance)
+ if got, want := counter1, test.wantCounter1; got != want {
+ t.Errorf("got counter1 = %d, want = %d", got, want)
+ }
+ if got, want := counter2, test.wantCounter2; got != want {
+ t.Errorf("got counter2 = %d, want = %d", got, want)
+ }
+ if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want {
+ t.Errorf("got elapsed = %d, want = %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index be03fb086..c00bcadfb 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -31,6 +31,27 @@ const (
// ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
ICMPv4MinimumSize = 8
+ // ICMPv4MinimumErrorPayloadSize Is the smallest number of bytes of an
+ // errant packet's transport layer that an ICMP error type packet should
+ // attempt to send as per RFC 792 (see each type) and RFC 1122
+ // section 3.2.2 which states:
+ // Every ICMP error message includes the Internet header and at
+ // least the first 8 data octets of the datagram that triggered
+ // the error; more than 8 octets MAY be sent; this header and data
+ // MUST be unchanged from the received datagram.
+ //
+ // RFC 792 shows:
+ // 0 1 2 3
+ // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Type | Code | Checksum |
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | unused |
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Internet Header + 64 bits of Original Data Datagram |
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ ICMPv4MinimumErrorPayloadSize = 8
+
// ICMPv4ProtocolNumber is the ICMP transport protocol number.
ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
@@ -39,15 +60,19 @@ const (
icmpv4ChecksumOffset = 2
// icmpv4MTUOffset is the offset of the MTU field
- // in a ICMPv4FragmentationNeeded message.
+ // in an ICMPv4FragmentationNeeded message.
icmpv4MTUOffset = 6
// icmpv4IdentOffset is the offset of the ident field
- // in a ICMPv4EchoRequest/Reply message.
+ // in an ICMPv4EchoRequest/Reply message.
icmpv4IdentOffset = 4
+ // icmpv4PointerOffset is the offset of the pointer field
+ // in an ICMPv4ParamProblem message.
+ icmpv4PointerOffset = 4
+
// icmpv4SequenceOffset is the offset of the sequence field
- // in a ICMPv4EchoRequest/Reply message.
+ // in an ICMPv4EchoRequest/Reply message.
icmpv4SequenceOffset = 6
)
@@ -72,15 +97,23 @@ const (
ICMPv4InfoReply ICMPv4Type = 16
)
+// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792.
+const (
+ ICMPv4TTLExceeded ICMPv4Code = 0
+)
+
// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792.
const (
- ICMPv4TTLExceeded ICMPv4Code = 0
+ ICMPv4NetUnreachable ICMPv4Code = 0
ICMPv4HostUnreachable ICMPv4Code = 1
ICMPv4ProtoUnreachable ICMPv4Code = 2
ICMPv4PortUnreachable ICMPv4Code = 3
ICMPv4FragmentationNeeded ICMPv4Code = 4
)
+// ICMPv4UnusedCode is a code to use in ICMP messages where no code is needed.
+const ICMPv4UnusedCode ICMPv4Code = 0
+
// Type is the ICMP type field.
func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
@@ -93,6 +126,15 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) }
+// SetPointer sets the pointer field in a Parameter error packet.
+// This is the first byte of the type specific data field.
+func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c }
+
+// SetTypeSpecific sets the full 32 bit type specific data field.
+func (b ICMPv4) SetTypeSpecific(val uint32) {
+ binary.BigEndian.PutUint32(b[icmpv4PointerOffset:], val)
+}
+
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 20b01d8f4..4eb5abd79 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -54,9 +54,17 @@ const (
// address.
ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize
- // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
+ // ICMPv6EchoMinimumSize is the minimum size of a valid echo packet.
ICMPv6EchoMinimumSize = 8
+ // ICMPv6ErrorHeaderSize is the size of an ICMP error packet header,
+ // as per RFC 4443, Apendix A, item 4 and the errata.
+ // ... all ICMP error messages shall have exactly
+ // 32 bits of type-specific data, so that receivers can reliably find
+ // the embedded invoking packet even when they don't recognize the
+ // ICMP message Type.
+ ICMPv6ErrorHeaderSize = 8
+
// ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
// destination unreachable packet.
ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize
@@ -69,6 +77,10 @@ const (
// in an ICMPv6 message.
icmpv6ChecksumOffset = 2
+ // icmpv6PointerOffset is the offset of the pointer
+ // in an ICMPv6 Parameter problem message.
+ icmpv6PointerOffset = 4
+
// icmpv6MTUOffset is the offset of the MTU field in an ICMPv6
// PacketTooBig message.
icmpv6MTUOffset = 4
@@ -89,9 +101,10 @@ const (
NDPHopLimit = 255
)
-// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
+// ICMPv6Type is the ICMP type field described in RFC 4443.
type ICMPv6Type byte
+// Values for use in the Type field of ICMPv6 packet from RFC 4433.
const (
ICMPv6DstUnreachable ICMPv6Type = 1
ICMPv6PacketTooBig ICMPv6Type = 2
@@ -109,7 +122,18 @@ const (
ICMPv6RedirectMsg ICMPv6Type = 137
)
-// ICMPv6Code is the ICMP code field described in RFC 4443.
+// IsErrorType returns true if the receiver is an ICMP error type.
+func (typ ICMPv6Type) IsErrorType() bool {
+ // Per RFC 4443 section 2.1:
+ // ICMPv6 messages are grouped into two classes: error messages and
+ // informational messages. Error messages are identified as such by a
+ // zero in the high-order bit of their message Type field values. Thus,
+ // error messages have message types from 0 to 127; informational
+ // messages have message types from 128 to 255.
+ return typ&0x80 == 0
+}
+
+// ICMPv6Code is the ICMP Code field described in RFC 4443.
type ICMPv6Code byte
// ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443
@@ -153,6 +177,11 @@ func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) }
+// SetTypeSpecific sets the full 32 bit type specific data field.
+func (b ICMPv6) SetTypeSpecific(val uint32) {
+ binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val)
+}
+
// Checksum is the ICMP checksum field.
func (b ICMPv6) Checksum() uint16 {
return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:])
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index e8816c3f4..b07d9991d 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -80,7 +80,8 @@ type IPv4Fields struct {
type IPv4 []byte
const (
- // IPv4MinimumSize is the minimum size of a valid IPv4 packet.
+ // IPv4MinimumSize is the minimum size of a valid IPv4 packet;
+ // i.e. a packet header with no options.
IPv4MinimumSize = 20
// IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
@@ -327,7 +328,7 @@ func IsV4MulticastAddress(addr tcpip.Address) bool {
}
// IsV4LoopbackAddress determines if the provided address is an IPv4 loopback
-// address (belongs to 127.0.0.1/8 subnet).
+// address (belongs to 127.0.0.0/8 subnet). See RFC 1122 section 3.2.1.3.
func IsV4LoopbackAddress(addr tcpip.Address) bool {
if len(addr) != IPv4AddressSize {
return false
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index cb9225bd7..81e286e80 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -238,6 +238,12 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return 0, false, parse.ARP(pkt)
}
+// ReturnError implements stack.TransportProtocol.ReturnError.
+func (*protocol) ReturnError(*stack.Route, tcpip.ICMPReason, *stack.PacketBuffer) *tcpip.Error {
+ // In ARP, there is no such response so do nothing.
+ return nil
+}
+
// NewProtocol returns an ARP network protocol.
func NewProtocol() stack.NetworkProtocol {
return &protocol{}
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index b5659a36b..5fe73315f 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,6 +15,7 @@
package ipv4
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -105,7 +106,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
// source address MUST be one of its own IP addresses (but not a broadcast
// or multicast address).
localAddr := r.LocalAddress
- if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) {
+ if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) {
localAddr = ""
}
@@ -131,7 +132,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: dataVV,
})
-
+ // TODO(gvisor.dev/issue/3810): When adding protocol numbers into the header
+ // information we will have to change this code to handle the ICMP header
+ // no longer being in the data buffer.
+ replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
// Send out the reply packet.
sent := stats.ICMP.V4PacketsSent
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
@@ -193,3 +197,175 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
received.Invalid.Increment()
}
}
+
+// ======= ICMP Error packet generation =========
+
+// ReturnError implements stack.TransportProtocol.ReturnError.
+func (p *protocol) ReturnError(r *stack.Route, reason tcpip.ICMPReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ switch reason.(type) {
+ case *tcpip.ICMPReasonPortUnreachable:
+ return returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ default:
+ return tcpip.ErrNotSupported
+ }
+}
+
+// icmpReason is a marker interface for IPv4 specific ICMP errors.
+type icmpReason interface {
+ isICMPReason()
+}
+
+// icmpReasonPortUnreachable is an error where the transport protocol has no
+// listener and no alternative means to inform the sender.
+type icmpReasonPortUnreachable struct{}
+
+func (*icmpReasonPortUnreachable) isICMPReason() {}
+
+// returnError takes an error descriptor and generates the appropriate ICMP
+// error packet for IPv4 and sends it back to the remote device that sent
+// the problematic packet. It incorporates as much of that packet as
+// possible as well as any error metadata as is available. returnError
+// expects pkt to hold a valid IPv4 packet as per the wire format.
+func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ sent := r.Stats().ICMP.V4PacketsSent
+ if !r.Stack().AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
+ // We check we are responding only when we are allowed to.
+ // See RFC 1812 section 4.3.2.7 (shown below).
+ //
+ // =========
+ // 4.3.2.7 When Not to Send ICMP Errors
+ //
+ // An ICMP error message MUST NOT be sent as the result of receiving:
+ //
+ // o An ICMP error message, or
+ //
+ // o A packet which fails the IP header validation tests described in
+ // Section [5.2.2] (except where that section specifically permits
+ // the sending of an ICMP error message), or
+ //
+ // o A packet destined to an IP broadcast or IP multicast address, or
+ //
+ // o A packet sent as a Link Layer broadcast or multicast, or
+ //
+ // o Any fragment of a datagram other then the first fragment (i.e., a
+ // packet for which the fragment offset in the IP header is nonzero).
+ //
+ // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in
+ // response to a non-initial fragment, but it currently can not happen.
+
+ if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any {
+ return nil
+ }
+
+ networkHeader := pkt.NetworkHeader().View()
+ transportHeader := pkt.TransportHeader().View()
+
+ // Don't respond to icmp error packets.
+ if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) {
+ // TODO(gvisor.dev/issue/3810):
+ // Unfortunately the current stack pretty much always has ICMPv4 headers
+ // in the Data section of the packet but there is no guarantee that is the
+ // case. If this is the case grab the header to make it like all other
+ // packet types. When this is cleaned up the Consume should be removed.
+ if transportHeader.IsEmpty() {
+ var ok bool
+ transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize)
+ if !ok {
+ return nil
+ }
+ } else if transportHeader.Size() < header.ICMPv4MinimumSize {
+ return nil
+ }
+ // We need to decide to explicitly name the packets we can respond to or
+ // the ones we can not respond to. The decision is somewhat arbitrary and
+ // if problems arise this could be reversed. It was judged less of a breach
+ // of protocol to not respond to unknown non-error packets than to respond
+ // to unknown error packets so we take the first approach.
+ switch header.ICMPv4(transportHeader).Type() {
+ case
+ header.ICMPv4EchoReply,
+ header.ICMPv4Echo,
+ header.ICMPv4Timestamp,
+ header.ICMPv4TimestampReply,
+ header.ICMPv4InfoRequest,
+ header.ICMPv4InfoReply:
+ default:
+ // Assume any type we don't know about may be an error type.
+ return nil
+ }
+ } else if transportHeader.IsEmpty() {
+ return nil
+ }
+
+ // Now work out how much of the triggering packet we should return.
+ // As per RFC 1812 Section 4.3.2.3
+ //
+ // ICMP datagram SHOULD contain as much of the original
+ // datagram as possible without the length of the ICMP
+ // datagram exceeding 576 bytes.
+ //
+ // NOTE: The above RFC referenced is different from the original
+ // recommendation in RFC 1122 and RFC 792 where it mentioned that at
+ // least 8 bytes of the payload must be included. Today linux and other
+ // systems implement the RFC 1812 definition and not the original
+ // requirement. We treat 8 bytes as the minimum but will try send more.
+ mtu := int(r.MTU())
+ if mtu > header.IPv4MinimumProcessableDatagramSize {
+ mtu = header.IPv4MinimumProcessableDatagramSize
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ available := int(mtu) - headerLen
+
+ if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize {
+ return nil
+ }
+
+ payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+
+ // The buffers used by pkt may be used elsewhere in the system.
+ // For example, an AF_RAW or AF_PACKET socket may use what the transport
+ // protocol considers an unreachable destination. Thus we deep copy pkt to
+ // prevent multiple ownership and SR errors. The new copy is a vectorized
+ // view with the entire incoming IP packet reassembled and truncated as
+ // required. This is now the payload of the new ICMP packet and no longer
+ // considered a packet in its own right.
+ newHeader := append(buffer.View(nil), networkHeader...)
+ newHeader = append(newHeader, transportHeader...)
+ payload := newHeader.ToVectorisedView()
+ payload.AppendView(pkt.Data.ToView())
+ payload.CapLength(payloadLen)
+
+ icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: headerLen,
+ Data: payload,
+ })
+ icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
+
+ icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ counter := sent.DstUnreachable
+ icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
+
+ if err := r.WritePacket(
+ nil, /* gso */
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv4ProtocolNumber,
+ TTL: r.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ },
+ icmpPkt,
+ ); err != nil {
+ sent.Dropped.Increment()
+ return err
+ }
+ counter.Increment()
+ return nil
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b14b356d6..135444222 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -455,6 +455,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
+ // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport
+ // headers, the setting of the transport number here should be
+ // unnecessary and removed.
+ pkt.TransportProtocolNumber = p
e.handleICMP(r, pkt)
return
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index b14bc98e8..86187aba8 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -17,6 +17,7 @@ package ipv4_test
import (
"bytes"
"encoding/hex"
+ "math"
"testing"
"github.com/google/go-cmp/cmp"
@@ -160,47 +161,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
}
-type testRoute struct {
- stack.Route
-
- linkEP *testutil.TestEndpoint
-}
-
-func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute {
- // Make the packet and write it.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- })
- testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors)
- s.CreateNIC(1, testEP)
- const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
- )
- s.AddAddress(1, ipv4.ProtocolNumber, src)
- {
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }})
- }
- r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute got %v, want %v", err, nil)
- }
- t.Cleanup(func() {
- testEP.Close()
- })
- return testRoute{
- Route: r,
- linkEP: testEP,
- }
-}
-
func TestFragmentation(t *testing.T) {
var manyPayloadViewsSizes [1000]int
for i := range manyPayloadViewsSizes {
@@ -228,7 +188,8 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
source := pkt.Clone()
err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
@@ -237,16 +198,16 @@ func TestFragmentation(t *testing.T) {
TOS: stack.DefaultTOS,
}, pkt)
if err != nil {
- t.Errorf("err got %v, want %v", err, nil)
+ t.Errorf("got err = %s, want = nil", err)
}
- if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want {
- t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
+ if got := len(ep.WrittenPackets); got != ft.expectedFrags {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags)
}
- if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
+ if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
- compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu)
+ compareFragments(t, ep.WrittenPackets, source, ft.mtu)
})
}
}
@@ -259,35 +220,30 @@ func TestFragmentationErrors(t *testing.T) {
mtu uint32
transportHeaderLength int
payloadViewsSizes []int
- packetCollectorErrors []*tcpip.Error
+ err *tcpip.Error
+ allowPackets int
}{
- {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
- {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0},
+ {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0},
+ {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1},
+ {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0},
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets)
+ r := buildRoute(t, ep)
pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
}, pkt)
- for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
- if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
- t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
- }
- }
- // We only need to check that last error because all the ones before are
- // nil.
- if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
- t.Errorf("err got %v, want %v", got, want)
+ if err != ft.err {
+ t.Errorf("got WritePacket() = %s, want = %s", err, ft.err)
}
- if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
- t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
+ if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
})
}
@@ -1052,7 +1008,7 @@ func TestWriteStats(t *testing.T) {
tests := []struct {
name string
setup func(*testing.T, *stack.Stack)
- linkEP func() stack.LinkEndpoint
+ allowPackets int
expectSent int
expectDropped int
expectWritten int
@@ -1061,7 +1017,7 @@ func TestWriteStats(t *testing.T) {
name: "Accept all",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: nPackets,
expectDropped: 0,
expectWritten: nPackets,
@@ -1069,7 +1025,7 @@ func TestWriteStats(t *testing.T) {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} },
+ allowPackets: nPackets - 1,
expectSent: nPackets - 1,
expectDropped: 0,
expectWritten: nPackets - 1,
@@ -1086,10 +1042,10 @@ func TestWriteStats(t *testing.T) {
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = stack.DropTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
+ t.Fatalf("failed to replace table: %s", err)
}
},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: 0,
expectDropped: nPackets,
expectWritten: nPackets,
@@ -1111,10 +1067,10 @@ func TestWriteStats(t *testing.T) {
// Make sure the next rule is ACCEPT.
filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
+ t.Fatalf("failed to replace table: %s", err)
}
},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: nPackets - 1,
expectDropped: 1,
expectWritten: nPackets,
@@ -1150,7 +1106,8 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- rt := buildRoute(t, nil, test.linkEP())
+ ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
@@ -1181,101 +1138,37 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
})
- s.CreateNIC(1, linkEP)
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC(1, _) failed: %s", err)
+ }
const (
src = "\x10\x00\x00\x01"
dst = "\x10\x00\x00\x02"
)
- s.AddAddress(1, ipv4.ProtocolNumber, src)
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err)
+ }
{
subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
if err != nil {
- t.Fatal(err)
+ t.Fatalf("NewSubnet(_, _) failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
NIC: 1,
}})
}
- rt, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
+ rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err)
}
return rt
}
-// limitedEP is a link endpoint that writes up to a certain number of packets
-// before returning errors.
-type limitedEP struct {
- limit int
-}
-
-// MTU implements LinkEndpoint.MTU.
-func (*limitedEP) MTU() uint32 {
- // Give an MTU that won't cause fragmentation for IPv4+UDP.
- return header.IPv4MinimumSize + header.UDPMinimumSize
-}
-
-// Capabilities implements LinkEndpoint.Capabilities.
-func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 }
-
-// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
-func (*limitedEP) MaxHeaderLength() uint16 { return 0 }
-
-// LinkAddress implements LinkEndpoint.LinkAddress.
-func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" }
-
-// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
- if ep.limit == 0 {
- return tcpip.ErrInvalidEndpointState
- }
- ep.limit--
- return nil
-}
-
-// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- if ep.limit == 0 {
- return 0, tcpip.ErrInvalidEndpointState
- }
- nWritten := ep.limit
- if nWritten > pkts.Len() {
- nWritten = pkts.Len()
- }
- ep.limit -= nWritten
- return nWritten, nil
-}
-
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
- if ep.limit == 0 {
- return tcpip.ErrInvalidEndpointState
- }
- ep.limit--
- return nil
-}
-
-// Attach implements LinkEndpoint.Attach.
-func (*limitedEP) Attach(_ stack.NetworkDispatcher) {}
-
-// IsAttached implements LinkEndpoint.IsAttached.
-func (*limitedEP) IsAttached() bool { return false }
-
-// Wait implements LinkEndpoint.Wait.
-func (*limitedEP) Wait() {}
-
-// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
-func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
-
-// AddHeader implements LinkEndpoint.AddHeader.
-func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
-}
-
// limitedMatcher is an iptables matcher that matches after a certain number of
// packets are checked against it.
type limitedMatcher struct {
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index cd5fe3ea8..8bd8f5c52 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -35,6 +35,7 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 2b83c421e..072c8ccd7 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -318,6 +318,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()),
})
packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet.SetType(header.ICMPv6NeighborAdvert)
na := header.NDPNeighborAdvert(packet.NDPPayload())
na.SetSolicitedFlag(solicited)
@@ -438,6 +439,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
Data: pkt.Data,
})
packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
copy(packet, icmpHdr)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
@@ -477,7 +479,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
stack := r.Stack()
// Is the networking stack operating as a router?
- if !stack.Forwarding() {
+ if !stack.Forwarding(ProtocolNumber) {
// ... No, silently drop the packet.
received.RouterOnlyPacketsDroppedByHost.Increment()
return
@@ -637,6 +639,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd
ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize,
})
icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr.SetType(header.ICMPv6NeighborSolicit)
copy(icmpHdr[icmpV6OptOffset-len(addr):], addr)
icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr
@@ -665,3 +668,123 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
return tcpip.LinkAddress([]byte(nil)), false
}
+
+// ======= ICMP Error packet generation =========
+
+// ReturnError implements stack.TransportProtocol.ReturnError.
+func (p *protocol) ReturnError(r *stack.Route, reason tcpip.ICMPReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ switch reason.(type) {
+ case *tcpip.ICMPReasonPortUnreachable:
+ return returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ default:
+ return tcpip.ErrNotSupported
+ }
+}
+
+// icmpReason is a marker interface for IPv6 specific ICMP errors.
+type icmpReason interface {
+ isICMPReason()
+}
+
+// icmpReasonPortUnreachable is an error where the transport protocol has no
+// listener and no alternative means to inform the sender.
+type icmpReasonPortUnreachable struct{}
+
+func (*icmpReasonPortUnreachable) isICMPReason() {}
+
+// returnError takes an error descriptor and generates the appropriate ICMP
+// error packet for IPv6 and sends it.
+func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ stats := r.Stats().ICMP
+ sent := stats.V6PacketsSent
+ if !r.Stack().AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
+ // Only send ICMP error if the address is not a multicast v6
+ // address and the source is not the unspecified address.
+ //
+ // TODO(b/164522993) There are exceptions to this rule.
+ // See: point e.3) RFC 4443 section-2.4
+ //
+ // (e) An ICMPv6 error message MUST NOT be originated as a result of
+ // receiving the following:
+ //
+ // (e.1) An ICMPv6 error message.
+ //
+ // (e.2) An ICMPv6 redirect message [IPv6-DISC].
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are
+ // two exceptions to this rule: (1) the Packet Too Big Message
+ // (Section 3.2) to allow Path MTU discovery to work for IPv6
+ // multicast, and (2) the Parameter Problem Message, Code 2
+ // (Section 3.4) reporting an unrecognized IPv6 option (see
+ // Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ //
+ if header.IsV6MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv6Any {
+ return nil
+ }
+
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+
+ if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
+ // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
+ // Unfortunately at this time ICMP Packets do not have a transport
+ // header separated out. It is in the Data part so we need to
+ // separate it out now. We will just pretend it is a minimal length
+ // ICMP packet as we don't really care if any later bits of a
+ // larger ICMP packet are in the header view or in the Data view.
+ transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize)
+ if !ok {
+ return nil
+ }
+ typ := header.ICMPv6(transport).Type()
+ if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg {
+ return nil
+ }
+ }
+
+ // As per RFC 4443 section 2.4
+ //
+ // (c) Every ICMPv6 error message (type < 128) MUST include
+ // as much of the IPv6 offending (invoking) packet (the
+ // packet that caused the error) as possible without making
+ // the error message packet exceed the minimum IPv6 MTU
+ // [IPv6].
+ mtu := int(r.MTU())
+ if mtu > header.IPv6MinimumMTU {
+ mtu = header.IPv6MinimumMTU
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
+ available := int(mtu) - headerLen
+ if available < header.IPv6MinimumSize {
+ return nil
+ }
+ payloadLen := network.Size() + transport.Size() + pkt.Data.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+ payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ payload.CapLength(payloadLen)
+
+ newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: headerLen,
+ Data: payload,
+ })
+ newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+
+ icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
+ icmpHdr.SetCode(header.ICMPv6PortUnreachable)
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data))
+ counter := sent.DstUnreachable
+ err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt)
+ if err != nil {
+ sent.Dropped.Increment()
+ return err
+ }
+ counter.Increment()
+ return nil
+}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 8112ed051..0f50bfb8e 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -728,7 +728,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
})
if isRouter {
// Enabling forwarding makes the stack act as a router.
- s.SetForwarding(true)
+ s.SetForwarding(ProtocolNumber, true)
}
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(_, _) = %s", err)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index ee64d92d8..5b1cca180 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -348,7 +348,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
it, done, err := it.Next()
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
if done {
@@ -476,6 +476,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
pkt.Data = extHdr.Buf
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
+ pkt.TransportProtocolNumber = p
e.handleICMP(r, pkt, hasFragmentHeader)
} else {
r.Stats().IP.PacketsDelivered.Increment()
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 9eea1de8d..7d138dadb 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -1715,7 +1716,7 @@ func TestWriteStats(t *testing.T) {
tests := []struct {
name string
setup func(*testing.T, *stack.Stack)
- linkEP func() stack.LinkEndpoint
+ allowPackets int
expectSent int
expectDropped int
expectWritten int
@@ -1724,7 +1725,7 @@ func TestWriteStats(t *testing.T) {
name: "Accept all",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: nPackets,
expectDropped: 0,
expectWritten: nPackets,
@@ -1732,7 +1733,7 @@ func TestWriteStats(t *testing.T) {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} },
+ allowPackets: nPackets - 1,
expectSent: nPackets - 1,
expectDropped: 0,
expectWritten: nPackets - 1,
@@ -1752,7 +1753,7 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %v", err)
}
},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: 0,
expectDropped: nPackets,
expectWritten: nPackets,
@@ -1777,7 +1778,7 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %v", err)
}
},
- linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ allowPackets: math.MaxInt32,
expectSent: nPackets - 1,
expectDropped: 1,
expectWritten: nPackets,
@@ -1812,7 +1813,8 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- rt := buildRoute(t, nil, test.linkEP())
+ ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
@@ -1843,100 +1845,37 @@ func TestWriteStats(t *testing.T) {
}
}
-func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route {
+func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
})
- s.CreateNIC(1, linkEP)
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC(1, _) failed: %s", err)
+ }
const (
src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
)
- s.AddAddress(1, ProtocolNumber, src)
+ if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err)
+ }
{
subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
if err != nil {
- t.Fatal(err)
+ t.Fatalf("NewSubnet(_, _) failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
NIC: 1,
}})
}
- rt, err := s.FindRoute(0, src, dst, ProtocolNumber, false /* multicastLoop */)
+ rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err)
}
return rt
}
-// limitedEP is a link endpoint that writes up to a certain number of packets
-// before returning errors.
-type limitedEP struct {
- limit int
-}
-
-// MTU implements LinkEndpoint.MTU.
-func (*limitedEP) MTU() uint32 {
- return header.IPv6MinimumMTU
-}
-
-// Capabilities implements LinkEndpoint.Capabilities.
-func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 }
-
-// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
-func (*limitedEP) MaxHeaderLength() uint16 { return 0 }
-
-// LinkAddress implements LinkEndpoint.LinkAddress.
-func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" }
-
-// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
- if ep.limit == 0 {
- return tcpip.ErrInvalidEndpointState
- }
- ep.limit--
- return nil
-}
-
-// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- if ep.limit == 0 {
- return 0, tcpip.ErrInvalidEndpointState
- }
- nWritten := ep.limit
- if nWritten > pkts.Len() {
- nWritten = pkts.Len()
- }
- ep.limit -= nWritten
- return nWritten, nil
-}
-
-// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
-func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
- if ep.limit == 0 {
- return tcpip.ErrInvalidEndpointState
- }
- ep.limit--
- return nil
-}
-
-// Attach implements LinkEndpoint.Attach.
-func (*limitedEP) Attach(_ stack.NetworkDispatcher) {}
-
-// IsAttached implements LinkEndpoint.IsAttached.
-func (*limitedEP) IsAttached() bool { return false }
-
-// Wait implements LinkEndpoint.Wait.
-func (*limitedEP) Wait() {}
-
-// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
-func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
-
-// AddHeader implements LinkEndpoint.AddHeader.
-func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
-}
-
// limitedMatcher is an iptables matcher that matches after a certain number of
// packets are checked against it.
type limitedMatcher struct {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 480c495fa..7434df4a1 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -958,7 +958,7 @@ func TestNDPValidation(t *testing.T) {
if isRouter {
// Enabling forwarding makes the stack act as a router.
- s.SetForwarding(true)
+ s.SetForwarding(ProtocolNumber, true)
}
stats := s.Stats().ICMP.V6PacketsReceived
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
index e218563d0..c9e57dc0d 100644
--- a/pkg/tcpip/network/testutil/BUILD
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -7,11 +7,14 @@ go_library(
srcs = [
"testutil.go",
],
- visibility = ["//pkg/tcpip/network/ipv4:__pkg__"],
+ visibility = [
+ "//pkg/tcpip/network/ipv4:__pkg__",
+ "//pkg/tcpip/network/ipv6:__pkg__",
+ ],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
- "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
index bf5ce74be..7cc52985e 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -22,48 +22,100 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// TestEndpoint is an endpoint used for testing, it stores packets written to it
-// and can mock errors.
-type TestEndpoint struct {
- *channel.Endpoint
-
- // WrittenPackets is where we store packets written via WritePacket().
+// MockLinkEndpoint is an endpoint used for testing, it stores packets written
+// to it and can mock errors.
+type MockLinkEndpoint struct {
+ // WrittenPackets is where packets written to the endpoint are stored.
WrittenPackets []*stack.PacketBuffer
- packetCollectorErrors []*tcpip.Error
+ mtu uint32
+ err *tcpip.Error
+ allowPackets int
}
-// NewTestEndpoint creates a new TestEndpoint endpoint.
+// NewMockLinkEndpoint creates a new MockLinkEndpoint.
//
-// packetCollectorErrors can be used to set error values and each call to
-// WritePacket will remove the first one from the slice and return it until
-// the slice is empty - at that point it will return nil every time.
-func NewTestEndpoint(ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) *TestEndpoint {
- return &TestEndpoint{
- Endpoint: ep,
- WrittenPackets: make([]*stack.PacketBuffer, 0),
- packetCollectorErrors: packetCollectorErrors,
+// err is the error that will be returned once allowPackets packets are written
+// to the endpoint.
+func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint {
+ return &MockLinkEndpoint{
+ mtu: mtu,
+ err: err,
+ allowPackets: allowPackets,
+ }
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu }
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 }
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
+
+// LinkAddress implements LinkEndpoint.LinkAddress.
+func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (ep *MockLinkEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if ep.allowPackets == 0 {
+ return ep.err
}
+ ep.allowPackets--
+ ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+ return nil
}
-// WritePacket stores outbound packets and may return an error if one was
-// injected.
-func (e *TestEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- e.WrittenPackets = append(e.WrittenPackets, pkt)
+// WritePackets implements LinkEndpoint.WritePackets.
+func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ var n int
- if len(e.packetCollectorErrors) > 0 {
- nextError := e.packetCollectorErrors[0]
- e.packetCollectorErrors = e.packetCollectorErrors[1:]
- return nextError
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := ep.WritePacket(r, gso, protocol, pkt); err != nil {
+ return n, err
+ }
+ n++
}
+ return n, nil
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ if ep.allowPackets == 0 {
+ return ep.err
+ }
+ ep.allowPackets--
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+ ep.WrittenPackets = append(ep.WrittenPackets, pkt)
+
return nil
}
+// Attach implements LinkEndpoint.Attach.
+func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (*MockLinkEndpoint) IsAttached() bool { return false }
+
+// Wait implements LinkEndpoint.Wait.
+func (*MockLinkEndpoint) Wait() {}
+
+// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
+func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
+
+// AddHeader implements LinkEndpoint.AddHeader.
+func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
// how many random bytes will be copied in the Transport Header.
// extraHeaderReserveLength indicates how much extra space will be reserved for
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 900938dd1..7f1d79115 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -138,7 +138,6 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
- "fake_time_test.go",
"forwarder_test.go",
"linkaddrcache_test.go",
"neighbor_cache_test.go",
@@ -152,8 +151,8 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
- "@com_github_dpjacques_clockwork//:go_default_library",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 54759091a..e30927821 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -145,6 +145,10 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
}
+func (*fwdTestNetworkProtocol) ReturnError(*Route, tcpip.ICMPReason, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
return &fwdTestNetworkEndpoint{
nicID: nicID,
@@ -316,7 +320,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC
}
// Enable forwarding.
- s.SetForwarding(true)
+ s.SetForwarding(proto.Number(), true)
// NIC 1 has the link address "a", and added the network address 1.
ep1 = &fwdTestLinkEndpoint{
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index b0873d1af..97ca00d16 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -817,7 +817,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// per-interface basis; it is a stack-wide configuration, so we check
// stack's forwarding flag to determine if the NIC is a routing
// interface.
- if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding {
+ if !ndp.configs.HandleRAs || ndp.nic.stack.Forwarding(header.IPv6ProtocolNumber) {
return
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 67dc5364f..5e43a9b0b 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -1120,7 +1120,7 @@ func TestNoRouterDiscovery(t *testing.T) {
},
NDPDisp: &ndpDisp,
})
- s.SetForwarding(forwarding)
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -1365,7 +1365,7 @@ func TestNoPrefixDiscovery(t *testing.T) {
},
NDPDisp: &ndpDisp,
})
- s.SetForwarding(forwarding)
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -1723,7 +1723,7 @@ func TestNoAutoGenAddr(t *testing.T) {
},
NDPDisp: &ndpDisp,
})
- s.SetForwarding(forwarding)
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -4640,7 +4640,7 @@ func TestCleanupNDPState(t *testing.T) {
name: "Enable forwarding",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
- s.SetForwarding(true)
+ s.SetForwarding(ipv6.ProtocolNumber, true)
},
keepAutoGenLinkLocal: true,
maxAutoGenAddrEvents: 4,
@@ -5286,11 +5286,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
name: "Enable and disable forwarding",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
- s.SetForwarding(false)
+ s.SetForwarding(ipv6.ProtocolNumber, false)
},
stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
- s.SetForwarding(true)
+ s.SetForwarding(ipv6.ProtocolNumber, true)
},
},
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index b4fa69e3e..a0b7da5cd 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -30,6 +30,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
const (
@@ -239,7 +240,7 @@ type entryEvent struct {
func TestNeighborCacheGetConfig(t *testing.T) {
nudDisp := testNUDDispatcher{}
c := DefaultNUDConfigurations()
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, c, clock)
if got, want := neigh.config(), c; got != want {
@@ -257,7 +258,7 @@ func TestNeighborCacheGetConfig(t *testing.T) {
func TestNeighborCacheSetConfig(t *testing.T) {
nudDisp := testNUDDispatcher{}
c := DefaultNUDConfigurations()
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, c, clock)
c.MinRandomFactor = 1
@@ -279,7 +280,7 @@ func TestNeighborCacheSetConfig(t *testing.T) {
func TestNeighborCacheEntry(t *testing.T) {
c := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, c, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -298,7 +299,7 @@ func TestNeighborCacheEntry(t *testing.T) {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
@@ -339,7 +340,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -358,7 +359,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
@@ -409,7 +410,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
}
type testContext struct {
- clock *fakeClock
+ clock *faketime.ManualClock
neigh *neighborCache
store *testEntryStore
linkRes *testNeighborResolver
@@ -418,7 +419,7 @@ type testContext struct {
func newTestContext(c NUDConfigurations) testContext {
nudDisp := &testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(nudDisp, c, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -454,7 +455,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock {
return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- c.clock.advance(c.neigh.config().RetransmitTimer)
+ c.clock.Advance(c.neigh.config().RetransmitTimer)
var wantEvents []testEntryEventInfo
@@ -567,7 +568,7 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- c.clock.advance(c.neigh.config().RetransmitTimer)
+ c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -803,7 +804,7 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- c.clock.advance(typicalLatency)
+ c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -876,7 +877,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -902,7 +903,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
if doneCh == nil {
t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
@@ -944,7 +945,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -974,7 +975,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
// Remove the waker before the neighbor cache has the opportunity to send a
// notification.
neigh.removeWaker(entry.Addr, &w)
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
@@ -1073,7 +1074,7 @@ func TestNeighborCacheClear(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1092,7 +1093,7 @@ func TestNeighborCacheClear(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
@@ -1188,7 +1189,7 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- c.clock.advance(typicalLatency)
+ c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1249,7 +1250,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
config.MaxRandomFactor = 1
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1277,7 +1278,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
@@ -1325,7 +1326,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
@@ -1412,7 +1413,7 @@ func TestNeighborCacheConcurrent(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1440,7 +1441,7 @@ func TestNeighborCacheConcurrent(t *testing.T) {
wg.Wait()
// Process all the requests for a single entry concurrently
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
}
// All goroutines add in the same order and add more values than can fit in
@@ -1472,7 +1473,7 @@ func TestNeighborCacheReplace(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1491,7 +1492,7 @@ func TestNeighborCacheReplace(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
@@ -1541,7 +1542,7 @@ func TestNeighborCacheReplace(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(config.DelayFirstProbeTime + typicalLatency)
+ clock.Advance(config.DelayFirstProbeTime + typicalLatency)
select {
case <-doneCh:
default:
@@ -1552,7 +1553,7 @@ func TestNeighborCacheReplace(t *testing.T) {
// Verify the entry's new link address
{
e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
if err != nil {
t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
}
@@ -1572,7 +1573,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
@@ -1595,7 +1596,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
if err != nil {
t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
@@ -1618,7 +1619,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
- clock.advance(waitFor)
+ clock.Advance(waitFor)
if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
}
@@ -1636,7 +1637,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
config := DefaultNUDConfigurations()
config.RetransmitTimer = time.Millisecond // small enough to cause timeout
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(nil, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1654,7 +1655,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
- clock.advance(waitFor)
+ clock.Advance(waitFor)
if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
}
@@ -1664,7 +1665,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
// resolved immediately and don't send resolution requests.
func TestNeighborCacheStaticResolution(t *testing.T) {
config := DefaultNUDConfigurations()
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(nil, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 0068cacb8..213646160 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -73,8 +73,7 @@ const (
type neighborEntry struct {
neighborEntryEntry
- nic *NIC
- protocol tcpip.NetworkProtocolNumber
+ nic *NIC
// linkRes provides the functionality to send reachability probes, used in
// Neighbor Unreachability Detection.
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index b769fb2fa..e530ec7ea 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -27,6 +27,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
const (
@@ -221,8 +222,8 @@ func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
return entryTestNetNumber
}
-func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) {
- clock := newFakeClock()
+func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) {
+ clock := faketime.NewManualClock()
disp := testNUDDispatcher{}
nic := NIC{
id: entryTestNICID,
@@ -267,7 +268,7 @@ func TestEntryInitiallyUnknown(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.RetransmitTimer)
+ clock.Advance(c.RetransmitTimer)
// No probes should have been sent.
linkRes.mu.Lock()
@@ -300,7 +301,7 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(time.Hour)
+ clock.Advance(time.Hour)
// No probes should have been sent.
linkRes.mu.Lock()
@@ -410,7 +411,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
updatedAt := e.neigh.UpdatedAt
e.mu.Unlock()
- clock.advance(c.RetransmitTimer)
+ clock.Advance(c.RetransmitTimer)
// UpdatedAt should remain the same during address resolution.
wantProbes := []entryTestProbeInfo{
@@ -439,7 +440,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.RetransmitTimer)
+ clock.Advance(c.RetransmitTimer)
// UpdatedAt should change after failing address resolution. Timing out after
// sending the last probe transitions the entry to Failed.
@@ -459,7 +460,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
}
}
- clock.advance(c.RetransmitTimer)
+ clock.Advance(c.RetransmitTimer)
wantEvents := []testEntryEventInfo{
{
@@ -748,7 +749,7 @@ func TestEntryIncompleteToFailed(t *testing.T) {
e.mu.Unlock()
waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes)
- clock.advance(waitFor)
+ clock.Advance(waitFor)
wantProbes := []entryTestProbeInfo{
// The Incomplete-to-Incomplete state transition is tested here by
@@ -983,7 +984,7 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -1612,7 +1613,7 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -1706,7 +1707,7 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -1989,7 +1990,7 @@ func TestEntryDelayToProbe(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2069,7 +2070,7 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2166,7 +2167,7 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2267,7 +2268,7 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2364,7 +2365,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// Probe caused by the Delay-to-Probe transition
@@ -2398,7 +2399,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -2463,7 +2464,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2503,7 +2504,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -2575,7 +2576,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2612,7 +2613,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
}
e.mu.Unlock()
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -2682,7 +2683,7 @@ func TestEntryProbeToFailed(t *testing.T) {
e.mu.Unlock()
waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
- clock.advance(waitFor)
+ clock.Advance(waitFor)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2787,7 +2788,7 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
e.mu.Unlock()
waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
- clock.advance(waitFor)
+ clock.Advance(waitFor)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 204bfc433..06d70dd1c 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -337,7 +337,7 @@ func (n *NIC) enable() *tcpip.Error {
// does. That is, routers do not learn from RAs (e.g. on-link prefixes
// and default routers). Therefore, soliciting RAs from other routers on
// a link is unnecessary for routers.
- if !n.stack.forwarding {
+ if !n.stack.Forwarding(header.IPv6ProtocolNumber) {
n.mu.ndp.startSolicitingRouters()
}
@@ -1242,9 +1242,9 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
local = n.linkEP.LinkAddress()
}
- // Are any packet sockets listening for this network protocol?
+ // Are any packet type sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
- // Add any other packet sockets that maybe listening for all protocols.
+ // Add any other packet type sockets that may be listening for all protocols.
packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
@@ -1265,6 +1265,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
return
}
if hasTransportHdr {
+ pkt.TransportProtocolNumber = transProtoNum
// Parse the transport header if present.
if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
state.proto.Parse(pkt)
@@ -1303,7 +1304,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// packet and forward it to the NIC.
//
// TODO: Should we be forwarding the packet even if promiscuous?
- if n.stack.Forwarding() {
+ if n.stack.Forwarding(protocol) {
r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
if err != nil {
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
@@ -1330,6 +1331,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// n doesn't have a destination endpoint.
// Send the packet out of n.
// TODO(b/128629022): move this logic to route.WritePacket.
+ // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
@@ -1452,10 +1454,28 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
}
}
- // We could not find an appropriate destination for this packet, so
- // deliver it to the global handler.
- if !transProto.HandleUnknownDestinationPacket(r, id, pkt) {
+ // We could not find an appropriate destination for this packet so
+ // give the protocol specific error handler a chance to handle it.
+ // If it doesn't handle it then we should do so.
+ switch transProto.HandleUnknownDestinationPacket(r, id, pkt) {
+ case UnknownDestinationPacketMalformed:
n.stack.stats.MalformedRcvdPackets.Increment()
+ case UnknownDestinationPacketUnhandled:
+ // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
+ // Unreachable messages with code:
+ // 3 (Port Unreachable), when the designated transport protocol
+ // (e.g., UDP) is unable to demultiplex the datagram but has no
+ // protocol mechanism to inform the sender.
+ np, ok := n.stack.networkProtocols[r.NetProto]
+ if !ok {
+ // For this to happen stack.makeRoute() must have been called with the
+ // incorrect protocol number. Since we have successfully completed
+ // network layer processing this should be impossible.
+ panic(fmt.Sprintf("expected stack to have a NetworkProtocol for proto = %d", r.NetProto))
+ }
+
+ _ = np.ReturnError(r, &tcpip.ICMPReasonPortUnreachable{}, pkt)
+ case UnknownDestinationPacketHandled:
}
}
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index dd6474297..ef6e63b3e 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -221,6 +221,11 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo
return 0, false, false
}
+// ReturnError implements NetworkProtocol.ReturnError.
+func (*testIPv6Protocol) ReturnError(*Route, tcpip.ICMPReason, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
// LinkAddressProtocol implements LinkAddressResolver.
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 1932aaeb7..a7d9d59fa 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -80,11 +80,17 @@ type PacketBuffer struct {
// data are held in the same underlying buffer storage.
header buffer.Prependable
- // NetworkProtocolNumber is only valid when NetworkHeader is set.
+ // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty()
+ // returns false.
// TODO(gvisor.dev/issue/3574): Remove the separately passed protocol
// numbers in registration APIs that take a PacketBuffer.
NetworkProtocolNumber tcpip.NetworkProtocolNumber
+ // TransportProtocol is only valid if it is non zero.
+ // TODO(gvisor.dev/issue/3810): This and the network protocol number should
+ // be moved into the headerinfo. This should resolve the validity issue.
+ TransportProtocolNumber tcpip.TransportProtocolNumber
+
// Hash is the transport layer hash of this packet. A value of zero
// indicates no valid hash has been set.
Hash uint32
@@ -234,16 +240,17 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
// underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
newPk := &PacketBuffer{
- PacketBufferEntry: pk.PacketBufferEntry,
- Data: pk.Data.Clone(nil),
- headers: pk.headers,
- header: pk.header,
- Hash: pk.Hash,
- Owner: pk.Owner,
- EgressRoute: pk.EgressRoute,
- GSOOptions: pk.GSOOptions,
- NetworkProtocolNumber: pk.NetworkProtocolNumber,
- NatDone: pk.NatDone,
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ headers: pk.headers,
+ header: pk.header,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ EgressRoute: pk.EgressRoute,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ TransportProtocolNumber: pk.TransportProtocolNumber,
}
return newPk
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 4fa86a3ac..77640cd8a 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -125,6 +125,26 @@ type PacketEndpoint interface {
HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
+// UnknownDestinationPacketDisposition enumerates the possible return vaues from
+// HandleUnknownDestinationPacket().
+type UnknownDestinationPacketDisposition int
+
+const (
+ // UnknownDestinationPacketMalformed denotes that the packet was malformed
+ // and no further processing should be attempted other than updating
+ // statistics.
+ UnknownDestinationPacketMalformed UnknownDestinationPacketDisposition = iota
+
+ // UnknownDestinationPacketUnhandled tells the caller that the packet was
+ // well formed but that the issue was not handled and the stack should take
+ // the default action.
+ UnknownDestinationPacketUnhandled
+
+ // UnknownDestinationPacketHandled tells the caller that it should do
+ // no further processing.
+ UnknownDestinationPacketHandled
+)
+
// TransportProtocol is the interface that needs to be implemented by transport
// protocols (e.g., tcp, udp) that want to be part of the networking stack.
type TransportProtocol interface {
@@ -147,14 +167,12 @@ type TransportProtocol interface {
ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this
- // protocol but that don't match any existing endpoint. For example,
- // it is targeted at a port that have no listeners.
+ // protocol that don't match any existing endpoint. For example,
+ // it is targeted at a port that has no listeners.
//
- // The return value indicates whether the packet was well-formed (for
- // stats purposes only).
- //
- // HandleUnknownDestinationPacket takes ownership of pkt.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
+ // HandleUnknownDestinationPacket takes ownership of pkt if it handles
+ // the issue.
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -324,6 +342,19 @@ type NetworkProtocol interface {
// does not encapsulate anything).
// - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader.
Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool)
+
+ // ReturnError attempts to send a suitable error message to the sender
+ // of a received packet.
+ // - pkt holds the problematic packet.
+ // - reason indicates what the reason for wanting a message is.
+ // - route is the routing information for the received packet
+ // ReturnError returns an error if the send failed and nil on success.
+ // Note that deciding to deliberately send no message is a success.
+ //
+ // TODO(gvisor.dev/issues/3871): This method should be removed or simplified
+ // after all (or all but one) of the ICMP error dispatch occurs through the
+ // protocol specific modules. May become SendPortNotFound(r, pkt).
+ ReturnError(r *Route, reason tcpip.ICMPReason, pkt *PacketBuffer) *tcpip.Error
}
// NetworkDispatcher contains the methods used by the network stack to deliver
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 6a683545d..e7b7e95d4 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -144,10 +144,7 @@ type TCPReceiverState struct {
// PendingBufUsed is the number of bytes pending in the receive
// queue.
- PendingBufUsed seqnum.Size
-
- // PendingBufSize is the size of the socket receive buffer.
- PendingBufSize seqnum.Size
+ PendingBufUsed int
}
// TCPSenderState holds a copy of the internal state of the sender for
@@ -405,6 +402,13 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+ // forwarding contains the whether packet forwarding is enabled or not for
+ // different network protocols.
+ forwarding struct {
+ sync.RWMutex
+ protocols map[tcpip.NetworkProtocolNumber]bool
+ }
+
// rawFactory creates raw endpoints. If nil, raw endpoints are
// disabled. It is set during Stack creation and is immutable.
rawFactory RawFactory
@@ -415,9 +419,8 @@ type Stack struct {
linkAddrCache *linkAddrCache
- mu sync.RWMutex
- nics map[tcpip.NICID]*NIC
- forwarding bool
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
// cleanupEndpointsMu protects cleanupEndpoints.
cleanupEndpointsMu sync.Mutex
@@ -749,6 +752,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
}
+ s.forwarding.protocols = make(map[tcpip.NetworkProtocolNumber]bool)
// Add specified network protocols.
for _, netProto := range opts.NetworkProtocols {
@@ -866,46 +870,42 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwarding enables or disables the packet forwarding between NICs.
-//
-// When forwarding becomes enabled, any host-only state on all NICs will be
-// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started.
-// When forwarding becomes disabled and if IPv6 is enabled, NDP Router
-// Solicitations will be stopped.
-func (s *Stack) SetForwarding(enable bool) {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.Lock()
- defer s.mu.Unlock()
+// SetForwarding enables or disables packet forwarding between NICs.
+func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) {
+ s.forwarding.Lock()
+ defer s.forwarding.Unlock()
- // If forwarding status didn't change, do nothing further.
- if s.forwarding == enable {
+ // If this stack does not support the protocol, do nothing.
+ if _, ok := s.networkProtocols[protocol]; !ok {
return
}
- s.forwarding = enable
-
- // If this stack does not support IPv6, do nothing further.
- if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok {
+ // If the forwarding value for this protocol hasn't changed then do
+ // nothing.
+ if forwarding := s.forwarding.protocols[protocol]; forwarding == enable {
return
}
- if enable {
- for _, nic := range s.nics {
- nic.becomeIPv6Router()
- }
- } else {
- for _, nic := range s.nics {
- nic.becomeIPv6Host()
+ s.forwarding.protocols[protocol] = enable
+
+ if protocol == header.IPv6ProtocolNumber {
+ if enable {
+ for _, nic := range s.nics {
+ nic.becomeIPv6Router()
+ }
+ } else {
+ for _, nic := range s.nics {
+ nic.becomeIPv6Host()
+ }
}
}
}
-// Forwarding returns if the packet forwarding between NICs is enabled.
-func (s *Stack) Forwarding() bool {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.forwarding
+// Forwarding returns if packet forwarding between NICs is enabled.
+func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
+ s.forwarding.RLock()
+ defer s.forwarding.RUnlock()
+ return s.forwarding.protocols[protocol]
}
// SetRouteTable assigns the route table to be used by this stack. It
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 60b54c244..9ef6787c6 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -216,13 +216,18 @@ func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption)
}
}
-// Close implements TransportProtocol.Close.
+// ReturnError implements NetworkProtocol.ReturnError
+func (*fakeNetworkProtocol) ReturnError(*stack.Route, tcpip.ICMPReason, *stack.PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// Close implements NetworkProtocol.Close.
func (*fakeNetworkProtocol) Close() {}
-// Wait implements TransportProtocol.Wait.
+// Wait implements NetworkProtocol.Wait.
func (*fakeNetworkProtocol) Wait() {}
-// Parse implements TransportProtocol.Parse.
+// Parse implements NetworkProtocol.Parse.
func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen)
if !ok {
@@ -2091,7 +2096,7 @@ func TestNICForwarding(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
- s.SetForwarding(true)
+ s.SetForwarding(fakeNetNumber, true)
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID1, ep1); err != nil {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index ef3457e32..cbb34d224 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -287,8 +287,8 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
- return true
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ return stack.UnknownDestinationPacketHandled
}
func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error {
@@ -549,7 +549,7 @@ func TestTransportForwarding(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
})
- s.SetForwarding(true)
+ s.SetForwarding(fakeNetNumber, true)
// TODO(b/123449044): Change this to a channel NIC.
ep1 := loopback.New()
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 464608dee..fa73cfa47 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1987,3 +1987,14 @@ func DeleteDanglingEndpoint(e Endpoint) {
// AsyncLoading is the global barrier for asynchronous endpoint loading
// activities.
var AsyncLoading sync.WaitGroup
+
+// ICMPReason is a marker interface for network protocol agnostic ICMP errors.
+type ICMPReason interface {
+ isICMP()
+}
+
+// ICMPReasonPortUnreachable is an error where the transport protocol has no
+// listener and no alternative means to inform the sender.
+type ICMPReasonPortUnreachable struct{}
+
+func (*ICMPReasonPortUnreachable) isICMP() {}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 31116309e..41eb0ca44 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -446,6 +446,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
pkt.Owner = owner
icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
copy(icmpv4, data)
// Set the ident to the user-specified port. Sequence number should
// already be set by the user.
@@ -478,6 +479,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
})
icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
copy(icmpv6, data)
// Set the ident. Sequence number is provided by the user.
icmpv6.SetIdent(ident)
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index bb11e4e83..941c3c08d 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,8 +104,8 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
- return true
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ return stack.UnknownDestinationPacketHandled
}
// SetOption implements stack.TransportProtocol.SetOption.
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 4778e7b1c..518449602 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -94,6 +94,7 @@ go_test(
shard_count = 10,
deps = [
":tcp",
+ "//pkg/rand",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 09d53d158..6891fd245 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -747,6 +747,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV
func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) {
optLen := len(tf.opts)
tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen))
+ pkt.TransportProtocolNumber = header.TCPProtocolNumber
tcp.Encode(&header.TCPFields{
SrcPort: tf.id.LocalPort,
DstPort: tf.id.RemotePort,
@@ -897,7 +898,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// sendRaw sends a TCP segment to the endpoint's peer.
func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
var sackBlocks []header.SACKBlock
- if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
@@ -1002,9 +1003,8 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// (indicated by a negative send window scale).
e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
- rcvBufSize := seqnum.Size(e.receiveBufferSize())
e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
// Bootstrap the auto tuning algorithm. Starting at zero will
// result in a really large receive window after the first auto
// tuning adjustment.
@@ -1135,12 +1135,11 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
}
cont, err := e.handleSegment(s)
+ s.decRef()
if err != nil {
- s.decRef()
return err
}
if !cont {
- s.decRef()
return nil
}
}
@@ -1220,6 +1219,12 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
return true, nil
}
+ // Increase counter if after processing the segment we would potentially
+ // advertise a zero window.
+ if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above {
+ e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
+ }
+
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
@@ -1232,7 +1237,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// or a notification from the protocolMainLoop (caller goroutine).
// This means that with this return, the segment dequeue below can
// never occur on a closed endpoint.
- s.decRef()
return false, nil
}
@@ -1424,10 +1428,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.rcv.nonZeroWindow()
}
- if n&notifyReceiveWindowChanged != 0 {
- e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize())
- }
-
if n&notifyMTUChanged != 0 {
e.sndBufMu.Lock()
count := e.packetTooBigCount
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 94207c141..560b4904c 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -78,8 +78,8 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network
ackCheckers := append(checkers, checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
))
checker.IPv4(t, c.GetPacket(), ackCheckers...)
@@ -185,8 +185,8 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network
ackCheckers := append(checkers, checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
))
checker.IPv6(t, c.GetV6Packet(), ackCheckers...)
@@ -283,7 +283,7 @@ func TestV4RefuseOnV6Only(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
}
@@ -319,7 +319,7 @@ func TestV6RefuseOnBoundToV4Mapped(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
}
@@ -352,7 +352,7 @@ func testV4Accept(t *testing.T, c *context.Context) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
@@ -492,7 +492,7 @@ func TestV6AcceptOnV6(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1),
+ checker.TCPAckNum(uint32(irs)+1),
),
)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 120483838..87db13720 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -63,6 +63,17 @@ const (
StateClosing
)
+const (
+ // rcvAdvWndScale is used to split the available socket buffer into
+ // application buffer and the window to be advertised to the peer. This is
+ // currently hard coded to split the available space equally.
+ rcvAdvWndScale = 1
+
+ // SegOverheadFactor is used to multiply the value provided by the
+ // user on a SetSockOpt for setting the socket send/receive buffer sizes.
+ SegOverheadFactor = 2
+)
+
// connected returns true when s is one of the states representing an
// endpoint connected to a peer.
func (s EndpointState) connected() bool {
@@ -149,7 +160,6 @@ func (s EndpointState) String() string {
// Reasons for notifying the protocol goroutine.
const (
notifyNonZeroReceiveWindow = 1 << iota
- notifyReceiveWindowChanged
notifyClose
notifyMTUChanged
notifyDrain
@@ -384,13 +394,26 @@ type endpoint struct {
// to indicate to users that no more data is coming.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
- rcvList segmentList `state:"wait"`
- rcvClosed bool
- rcvBufSize int
+ rcvListMu sync.Mutex `state:"nosave"`
+ rcvList segmentList `state:"wait"`
+ rcvClosed bool
+ // rcvBufSize is the total size of the receive buffer.
+ rcvBufSize int
+ // rcvBufUsed is the actual number of payload bytes held in the receive buffer
+ // not counting any overheads of the segments itself. NOTE: This will always
+ // be strictly <= rcvMemUsed below.
rcvBufUsed int
rcvAutoParams rcvBufAutoTuneParams
+ // rcvMemUsed tracks the total amount of memory in use by received segments
+ // held in rcvList, pendingRcvdSegments and the segment queue. This is used to
+ // compute the window and the actual available buffer space. This is distinct
+ // from rcvBufUsed above which is the actual number of payload bytes held in
+ // the buffer not including any segment overheads.
+ //
+ // rcvMemUsed must be accessed atomically.
+ rcvMemUsed int32
+
// mu protects all endpoint fields unless documented otherwise. mu must
// be acquired before interacting with the endpoint fields.
mu sync.Mutex `state:"nosave"`
@@ -891,7 +914,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.probe = p
}
- e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.segmentQueue.ep = e
e.tsOffset = timeStampOffset()
e.acceptCond = sync.NewCond(&e.acceptMu)
@@ -1129,10 +1152,16 @@ func (e *endpoint) cleanupLocked() {
tcpip.DeleteDanglingEndpoint(e)
}
+// wndFromSpace returns the window that we can advertise based on the available
+// receive buffer space.
+func wndFromSpace(space int) int {
+ return space / (1 << rcvAdvWndScale)
+}
+
// initialReceiveWindow returns the initial receive window to advertise in the
// SYN/SYN-ACK.
func (e *endpoint) initialReceiveWindow() int {
- rcvWnd := e.receiveBufferAvailable()
+ rcvWnd := wndFromSpace(e.receiveBufferAvailable())
if rcvWnd > math.MaxUint16 {
rcvWnd = math.MaxUint16
}
@@ -1209,14 +1238,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
// reject valid data that might already be in flight as the
// acceptable window will shrink.
if rcvWnd > e.rcvBufSize {
- availBefore := e.receiveBufferAvailableLocked()
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvBufSize = rcvWnd
- availAfter := e.receiveBufferAvailableLocked()
- mask := uint32(notifyReceiveWindowChanged)
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- mask |= notifyNonZeroReceiveWindow
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
- e.notifyProtocolGoroutine(mask)
}
// We only update prevCopied when we grow the buffer because in cases
@@ -1293,18 +1320,22 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
v := views[s.viewToDeliver]
s.viewToDeliver++
+ var delta int
if s.viewToDeliver >= len(views) {
e.rcvList.Remove(s)
+ // We only free up receive buffer space when the segment is released as the
+ // segment is still holding on to the views even though some views have been
+ // read out to the user.
+ delta = s.segMemSize()
s.decRef()
}
e.rcvBufUsed -= len(v)
-
// If the window was small before this read and if the read freed up
// enough buffer space, to either fit an aMSS or half a receive buffer
// (whichever smaller), then notify the protocol goroutine to send a
// window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
+ if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -1481,11 +1512,11 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
}
// windowCrossedACKThresholdLocked checks if the receive window to be announced
-// now would be under aMSS or under half receive buffer, whichever smaller. This
-// is useful as a receive side silly window syndrome prevention mechanism. If
-// window grows to reasonable value, we should send ACK to the sender to inform
-// the rx space is now large. We also want ensure a series of small read()'s
-// won't trigger a flood of spurious tiny ACK's.
+// would be under aMSS or under the window derived from half receive buffer,
+// whichever smaller. This is useful as a receive side silly window syndrome
+// prevention mechanism. If window grows to reasonable value, we should send ACK
+// to the sender to inform the rx space is now large. We also want ensure a
+// series of small read()'s won't trigger a flood of spurious tiny ACK's.
//
// For large receive buffers, the threshold is aMSS - once reader reads more
// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of
@@ -1496,17 +1527,18 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
//
// Precondition: e.mu and e.rcvListMu must be held.
func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
- newAvail := e.receiveBufferAvailableLocked()
+ newAvail := wndFromSpace(e.receiveBufferAvailableLocked())
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
oldAvail = 0
}
-
threshold := int(e.amss)
- if threshold > e.rcvBufSize/2 {
- threshold = e.rcvBufSize / 2
+ // rcvBufFraction is the inverse of the fraction of receive buffer size that
+ // is used to decide if the available buffer space is now above it.
+ const rcvBufFraction = 2
+ if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold {
+ threshold = wndThreshold
}
-
switch {
case oldAvail < threshold && newAvail >= threshold:
return true, true
@@ -1636,17 +1668,23 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// Make sure the receive buffer size is within the min and max
// allowed.
var rs tcpip.TCPReceiveBufferSizeRangeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err))
+ }
+
+ if v > rs.Max {
+ v = rs.Max
+ }
+
+ if v < math.MaxInt32/SegOverheadFactor {
+ v *= SegOverheadFactor
if v < rs.Min {
v = rs.Min
}
- if v > rs.Max {
- v = rs.Max
- }
+ } else {
+ v = math.MaxInt32
}
- mask := uint32(notifyReceiveWindowChanged)
-
e.LockUser()
e.rcvListMu.Lock()
@@ -1660,14 +1698,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
v = 1 << scale
}
- // Make sure 2*size doesn't overflow.
- if v > math.MaxInt32/2 {
- v = math.MaxInt32 / 2
- }
-
- availBefore := e.receiveBufferAvailableLocked()
+ availBefore := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvBufSize = v
- availAfter := e.receiveBufferAvailableLocked()
+ availAfter := wndFromSpace(e.receiveBufferAvailableLocked())
e.rcvAutoParams.disabled = true
@@ -1675,24 +1708,31 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// syndrome prevetion, when our available space grows above aMSS
// or half receive buffer, whichever smaller.
if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
- mask |= notifyNonZeroReceiveWindow
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
e.rcvListMu.Unlock()
e.UnlockUser()
- e.notifyProtocolGoroutine(mask)
case tcpip.SendBufferSizeOption:
// Make sure the send buffer size is within the min and max
// allowed.
var ss tcpip.TCPSendBufferSizeRangeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err))
+ }
+
+ if v > ss.Max {
+ v = ss.Max
+ }
+
+ if v < math.MaxInt32/SegOverheadFactor {
+ v *= SegOverheadFactor
if v < ss.Min {
v = ss.Min
}
- if v > ss.Max {
- v = ss.Max
- }
+ } else {
+ v = math.MaxInt32
}
e.sndBufMu.Lock()
@@ -2699,13 +2739,8 @@ func (e *endpoint) updateSndBufferUsage(v int) {
func (e *endpoint) readyToRead(s *segment) {
e.rcvListMu.Lock()
if s != nil {
+ e.rcvBufUsed += s.payloadSize()
s.incRef()
- e.rcvBufUsed += s.data.Size()
- // Increase counter if the receive window falls down below MSS
- // or half receive buffer size, whichever smaller.
- if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
- e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
- }
e.rcvList.PushBack(s)
} else {
e.rcvClosed = true
@@ -2720,15 +2755,17 @@ func (e *endpoint) readyToRead(s *segment) {
func (e *endpoint) receiveBufferAvailableLocked() int {
// We may use more bytes than the buffer size when the receive buffer
// shrinks.
- if e.rcvBufUsed >= e.rcvBufSize {
+ memUsed := e.receiveMemUsed()
+ if memUsed >= e.rcvBufSize {
return 0
}
- return e.rcvBufSize - e.rcvBufUsed
+ return e.rcvBufSize - memUsed
}
// receiveBufferAvailable calculates how many bytes are still available in the
-// receive buffer.
+// receive buffer based on the actual memory used by all segments held in
+// receive buffer/pending and segment queue.
func (e *endpoint) receiveBufferAvailable() int {
e.rcvListMu.Lock()
available := e.receiveBufferAvailableLocked()
@@ -2736,14 +2773,35 @@ func (e *endpoint) receiveBufferAvailable() int {
return available
}
+// receiveBufferUsed returns the amount of in-use receive buffer.
+func (e *endpoint) receiveBufferUsed() int {
+ e.rcvListMu.Lock()
+ used := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+ return used
+}
+
+// receiveBufferSize returns the current size of the receive buffer.
func (e *endpoint) receiveBufferSize() int {
e.rcvListMu.Lock()
size := e.rcvBufSize
e.rcvListMu.Unlock()
-
return size
}
+// receiveMemUsed returns the total memory in use by segments held by this
+// endpoint.
+func (e *endpoint) receiveMemUsed() int {
+ return int(atomic.LoadInt32(&e.rcvMemUsed))
+}
+
+// updateReceiveMemUsed adds the provided delta to e.rcvMemUsed.
+func (e *endpoint) updateReceiveMemUsed(delta int) {
+ atomic.AddInt32(&e.rcvMemUsed, int32(delta))
+}
+
+// maxReceiveBufferSize returns the stack wide maximum receive buffer size for
+// an endpoint.
func (e *endpoint) maxReceiveBufferSize() int {
var rs tcpip.TCPReceiveBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
@@ -2894,7 +2952,6 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
RcvAcc: e.rcv.rcvAcc,
RcvWndScale: e.rcv.rcvWndScale,
PendingBufUsed: e.rcv.pendingBufUsed,
- PendingBufSize: e.rcv.pendingBufSize,
}
// Copy sender state.
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 41d0050f3..b25431467 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -44,7 +44,7 @@ func (e *endpoint) drainSegmentLocked() {
// beforeSave is invoked by stateify.
func (e *endpoint) beforeSave() {
// Stop incoming packets.
- e.segmentQueue.setLimit(0)
+ e.segmentQueue.freeze()
e.mu.Lock()
defer e.mu.Unlock()
@@ -178,7 +178,7 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
- e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.segmentQueue.thaw()
epState := e.origEndpointState
switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 74a17af79..371067048 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -201,21 +201,20 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
s := newSegment(r, id, pkt)
defer s.decRef()
if !s.parse() || !s.csumValid {
- return false
+ return stack.UnknownDestinationPacketMalformed
}
- // There's nothing to do if this is already a reset packet.
- if s.flagIsSet(header.TCPFlagRst) {
- return true
+ if !s.flagIsSet(header.TCPFlagRst) {
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
}
- replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
- return true
+ return stack.UnknownDestinationPacketHandled
}
// replyWithReset replies to the given segment with a reset segment.
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index cfd43b5e3..4aafb4d22 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -47,22 +47,24 @@ type receiver struct {
closed bool
+ // pendingRcvdSegments is bounded by the receive buffer size of the
+ // endpoint.
pendingRcvdSegments segmentHeap
- pendingBufUsed seqnum.Size
- pendingBufSize seqnum.Size
+ // pendingBufUsed tracks the total number of bytes (including segment
+ // overhead) currently queued in pendingRcvdSegments.
+ pendingBufUsed int
// Time when the last ack was received.
lastRcvdAckTime time.Time `state:".(unixTime)"`
}
-func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver {
+func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
return &receiver{
ep: ep,
rcvNxt: irs + 1,
rcvAcc: irs.Add(rcvWnd + 1),
rcvWnd: rcvWnd,
rcvWndScale: rcvWndScale,
- pendingBufSize: pendingBufSize,
lastRcvdAckTime: time.Now(),
}
}
@@ -85,15 +87,23 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
// getSendParams returns the parameters needed by the sender when building
// segments to send.
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
- // Calculate the window size based on the available buffer space.
- receiveBufferAvailable := r.ep.receiveBufferAvailable()
- acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable))
- if r.rcvAcc.LessThan(acc) {
- r.rcvAcc = acc
+ avail := wndFromSpace(r.ep.receiveBufferAvailable())
+ acc := r.rcvNxt.Add(seqnum.Size(avail))
+ newWnd := r.rcvNxt.Size(acc)
+ curWnd := r.rcvNxt.Size(r.rcvAcc)
+
+ // Update rcvAcc only if new window is > previously advertised window. We
+ // should never shrink the acceptable sequence space once it has been
+ // advertised the peer. If we shrink the acceptable sequence space then we
+ // would end up dropping bytes that might already be in flight.
+ if newWnd > curWnd {
+ r.rcvAcc = r.rcvNxt.Add(newWnd)
+ } else {
+ newWnd = curWnd
}
// Stash away the non-scaled receive window as we use it for measuring
// receiver's estimated RTT.
- r.rcvWnd = r.rcvNxt.Size(r.rcvAcc)
+ r.rcvWnd = newWnd
return r.rcvNxt, r.rcvWnd >> r.rcvWndScale
}
@@ -195,7 +205,9 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
}
for i := first; i < len(r.pendingRcvdSegments); i++ {
+ r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize()
r.pendingRcvdSegments[i].decRef()
+
// Note that slice truncation does not allow garbage collection of
// truncated items, thus truncated items must be set to nil to avoid
// memory leaks.
@@ -384,10 +396,16 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
// Defer segment processing if it can't be consumed now.
if !r.consumeSegment(s, segSeq, segLen) {
if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
- // We only store the segment if it's within our buffer
- // size limit.
- if r.pendingBufUsed < r.pendingBufSize {
- r.pendingBufUsed += seqnum.Size(s.segMemSize())
+ // We only store the segment if it's within our buffer size limit.
+ //
+ // Only use 75% of the receive buffer queue for out-of-order
+ // segments. This ensures that we always leave some space for the inorder
+ // segments to arrive allowing pending segments to be processed and
+ // delivered to the user.
+ if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 {
+ r.ep.rcvListMu.Lock()
+ r.pendingBufUsed += s.segMemSize()
+ r.ep.rcvListMu.Unlock()
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
@@ -421,7 +439,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
}
heap.Pop(&r.pendingRcvdSegments)
- r.pendingBufUsed -= seqnum.Size(s.segMemSize())
+ r.ep.rcvListMu.Lock()
+ r.pendingBufUsed -= s.segMemSize()
+ r.ep.rcvListMu.Unlock()
s.decRef()
}
return false, nil
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 94307d31a..13acaf753 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "fmt"
"sync/atomic"
"time"
@@ -24,6 +25,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// queueFlags are used to indicate which queue of an endpoint a particular segment
+// belongs to. This is used to track memory accounting correctly.
+type queueFlags uint8
+
+const (
+ recvQ queueFlags = 1 << iota
+ sendQ
+)
+
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
// segment is mostly immutable, the only field allowed to change is viewToDeliver.
@@ -32,6 +42,8 @@ import (
type segment struct {
segmentEntry
refCnt int32
+ ep *endpoint
+ qFlags queueFlags
id stack.TransportEndpointID `state:"manual"`
route stack.Route `state:"manual"`
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
@@ -100,6 +112,8 @@ func (s *segment) clone() *segment {
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
xmitCount: s.xmitCount,
+ ep: s.ep,
+ qFlags: s.qFlags,
}
t.data = s.data.Clone(t.views[:])
return t
@@ -115,8 +129,34 @@ func (s *segment) flagsAreSet(flags uint8) bool {
return s.flags&flags == flags
}
+// setOwner sets the owning endpoint for this segment. Its required
+// to be called to ensure memory accounting for receive/send buffer
+// queues is done properly.
+func (s *segment) setOwner(ep *endpoint, qFlags queueFlags) {
+ switch qFlags {
+ case recvQ:
+ ep.updateReceiveMemUsed(s.segMemSize())
+ case sendQ:
+ // no memory account for sendQ yet.
+ default:
+ panic(fmt.Sprintf("unexpected queue flag %b", qFlags))
+ }
+ s.ep = ep
+ s.qFlags = qFlags
+}
+
func (s *segment) decRef() {
if atomic.AddInt32(&s.refCnt, -1) == 0 {
+ if s.ep != nil {
+ switch s.qFlags {
+ case recvQ:
+ s.ep.updateReceiveMemUsed(-s.segMemSize())
+ case sendQ:
+ // no memory accounting for sendQ yet.
+ default:
+ panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags))
+ }
+ }
s.route.Release()
}
}
@@ -138,6 +178,11 @@ func (s *segment) logicalLen() seqnum.Size {
return l
}
+// payloadSize is the size of s.data.
+func (s *segment) payloadSize() int {
+ return s.data.Size()
+}
+
// segMemSize is the amount of memory used to hold the segment data and
// the associated metadata.
func (s *segment) segMemSize() int {
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index 48a257137..54545a1b1 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -22,16 +22,16 @@ import (
//
// +stateify savable
type segmentQueue struct {
- mu sync.Mutex `state:"nosave"`
- list segmentList `state:"wait"`
- limit int
- used int
+ mu sync.Mutex `state:"nosave"`
+ list segmentList `state:"wait"`
+ ep *endpoint
+ frozen bool
}
// emptyLocked determines if the queue is empty.
// Preconditions: q.mu must be held.
func (q *segmentQueue) emptyLocked() bool {
- return q.used == 0
+ return q.list.Empty()
}
// empty determines if the queue is empty.
@@ -43,14 +43,6 @@ func (q *segmentQueue) empty() bool {
return r
}
-// setLimit updates the limit. No segments are immediately dropped in case the
-// queue becomes full due to the new limit.
-func (q *segmentQueue) setLimit(limit int) {
- q.mu.Lock()
- q.limit = limit
- q.mu.Unlock()
-}
-
// enqueue adds the given segment to the queue.
//
// Returns true when the segment is successfully added to the queue, in which
@@ -58,15 +50,23 @@ func (q *segmentQueue) setLimit(limit int) {
// false if the queue is full, in which case ownership is retained by the
// caller.
func (q *segmentQueue) enqueue(s *segment) bool {
+ // q.ep.receiveBufferParams() must be called without holding q.mu to
+ // avoid lock order inversion.
+ bufSz := q.ep.receiveBufferSize()
+ used := q.ep.receiveMemUsed()
q.mu.Lock()
- r := q.used < q.limit
- if r {
+ // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue
+ // is currently full).
+ allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen
+
+ if allow {
q.list.PushBack(s)
- q.used++
+ // Set the owner now that the endpoint owns the segment.
+ s.setOwner(q.ep, recvQ)
}
q.mu.Unlock()
- return r
+ return allow
}
// dequeue removes and returns the next segment from queue, if one exists.
@@ -77,9 +77,25 @@ func (q *segmentQueue) dequeue() *segment {
s := q.list.Front()
if s != nil {
q.list.Remove(s)
- q.used--
}
q.mu.Unlock()
return s
}
+
+// freeze prevents any more segments from being added to the queue. i.e all
+// future segmentQueue.enqueue will return false and not add the segment to the
+// queue till the queue is unfroze with a corresponding segmentQueue.thaw call.
+func (q *segmentQueue) freeze() {
+ q.mu.Lock()
+ q.frozen = true
+ q.mu.Unlock()
+}
+
+// thaw unfreezes a previously frozen queue using segmentQueue.freeze() and
+// allows new segments to be queued again.
+func (q *segmentQueue) thaw() {
+ q.mu.Lock()
+ q.frozen = false
+ q.mu.Unlock()
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index b1e5f1b24..8326736dc 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -240,6 +241,38 @@ func TestTCPResetsSentIncrement(t *testing.T) {
}
}
+// TestTCPResetsSentNoICMP confirms that we don't get an ICMP
+// DstUnreachable packet when we try send a packet which is not part
+// of an active session.
+func TestTCPResetsSentNoICMP(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ stats := c.Stack().Stats()
+
+ // Send a SYN request for a closed port. This should elicit an RST
+ // but NOT an ICMPv4 DstUnreachable packet.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive whatever comes back.
+ b := c.GetPacket()
+ ipHdr := header.IPv4(b)
+ if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want {
+ t.Errorf("unexpected protocol, got = %d, want = %d", got, want)
+ }
+
+ // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded.
+ sent := stats.ICMP.V4PacketsSent
+ if got, want := sent.DstUnreachable.Value(), uint64(0); got != want {
+ t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want)
+ }
+}
+
// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
// a RST if an ACK is received on the listening socket for which there is no
// active handshake in progress and we are not using SYN cookies.
@@ -317,8 +350,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
SrcPort: context.TestPort,
@@ -348,8 +381,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst)))
}
@@ -447,8 +480,8 @@ func TestConnectResetAfterClose(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -489,8 +522,8 @@ func TestConnectResetAfterClose(t *testing.T) {
// RST is always generated with sndNxt which if the FIN
// has been sent will be 1 higher than the sequence number
// of the FIN itself.
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -529,8 +562,8 @@ func TestCurrentConnectedIncrement(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -565,8 +598,8 @@ func TestCurrentConnectedIncrement(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -612,8 +645,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -633,8 +666,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -693,8 +726,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -745,8 +778,8 @@ func TestSimpleReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -998,7 +1031,7 @@ func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
@@ -1026,7 +1059,7 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
@@ -1063,8 +1096,8 @@ func TestTCPAckBeforeAcceptV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
@@ -1101,8 +1134,8 @@ func TestTCPAckBeforeAcceptV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
func TestSendRstOnListenerRxAckV4(t *testing.T) {
@@ -1130,7 +1163,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
func TestSendRstOnListenerRxAckV6(t *testing.T) {
@@ -1158,7 +1191,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) {
checker.IPv6(t, c.GetV6Packet(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
+ checker.TCPSeqNum(200)))
}
// TestListenShutdown tests for the listening endpoint replying with RST
@@ -1274,8 +1307,8 @@ func TestTOSV4(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790), // Acknum is initial sequence number + 1
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790), // Acknum is initial sequence number + 1
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
checker.TOS(tos, 0),
@@ -1323,8 +1356,8 @@ func TestTrafficClassV6(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
checker.TOS(tos, 0),
@@ -1514,8 +1547,8 @@ func TestOutOfOrderReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1565,8 +1598,8 @@ func TestOutOfOrderReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1576,8 +1609,8 @@ func TestOutOfOrderFlood(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- // Create a new connection with initial window size of 10.
- c.CreateConnected(789, 30000, 10)
+ rcvBufSz := math.MaxUint16
+ c.CreateConnected(789, 30000, rcvBufSz)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
@@ -1598,8 +1631,8 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1619,8 +1652,8 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1639,8 +1672,8 @@ func TestOutOfOrderFlood(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(793),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(793),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1681,8 +1714,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1696,7 +1729,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
// We shouldn't consume a sequence number on RST.
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
@@ -1750,8 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1764,7 +1797,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
))
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
@@ -1783,7 +1816,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
// RST is always generated with sndNxt which if the FIN
// has been sent will be 1 higher than the sequence
// number of the FIN itself.
- checker.SeqNum(uint32(c.IRS)+2),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
@@ -1829,7 +1862,8 @@ func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, 10)
+ const rcvBufSz = 10
+ c.CreateConnected(789, 30000, rcvBufSz)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1840,8 +1874,13 @@ func TestFullWindowReceive(t *testing.T) {
t.Fatalf("Read failed: %s", err)
}
- // Fill up the window.
- data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
+ // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
+ // the provided buffer value by tcp.SegOverheadFactor to calculate the actual
+ // receive buffer size.
+ data := make([]byte, tcp.SegOverheadFactor*rcvBufSz)
+ for i := range data {
+ data[i] = byte(i % 255)
+ }
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
@@ -1862,10 +1901,10 @@ func TestFullWindowReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(0),
+ checker.TCPWindow(0),
),
)
@@ -1888,10 +1927,10 @@ func TestFullWindowReceive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+len(data))),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(10),
+ checker.TCPWindow(10),
),
)
}
@@ -1900,12 +1939,15 @@ func TestNoWindowShrinking(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- // Start off with a window size of 10, then shrink it to 5.
- c.CreateConnected(789, 30000, 10)
-
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
- t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
- }
+ // Start off with a certain receive buffer then cut it in half and verify that
+ // the right edge of the window does not shrink.
+ // NOTE: Netstack doubles the value specified here.
+ rcvBufSize := 65536
+ iss := seqnum.Value(789)
+ // Enable window scaling with a scale of zero from our end.
+ c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{
+ header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
+ })
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -1914,14 +1956,15 @@ func TestNoWindowShrinking(t *testing.T) {
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
-
- // Send 3 bytes, check that the peer acknowledges them.
- data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
- c.SendPacket(data[:3], &context.Headers{
+ // Send a 1 byte payload so that we can record the current receive window.
+ // Send a payload of half the size of rcvBufSize.
+ seqNum := iss.Add(1)
+ payload := []byte{1}
+ c.SendPacket(payload, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 790,
+ SeqNum: seqNum,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
@@ -1933,46 +1976,93 @@ func TestNoWindowShrinking(t *testing.T) {
t.Fatalf("Timed out waiting for data to arrive")
}
- // Check that data is acknowledged, and that window doesn't go to zero
- // just yet because it was previously set to 10. It must go to 7 now.
- checker.IPv4(t, c.GetPacket(),
+ // Read the 1 byte payload we just sent.
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+ if got, want := payload, v; !bytes.Equal(got, want) {
+ t.Fatalf("got data: %v, want: %v", got, want)
+ }
+
+ seqNum = seqNum.Add(1)
+ // Verify that the ACK does not shrink the window.
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(793),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(7),
),
)
+ // Stash the initial window.
+ initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
+ initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd))
+ // Now shrink the receive buffer to half its original size.
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
+ }
- // Send 7 more bytes, check that the window fills up.
- c.SendPacket(data[3:], &context.Headers{
+ data := generateRandomPayload(t, rcvBufSize)
+ // Send a payload of half the size of rcvBufSize.
+ c.SendPacket(data[:rcvBufSize/2], &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
- SeqNum: 793,
+ SeqNum: seqNum,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
+ seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2))
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
+ // Verify that the ACK does not shrink the window.
+ pkt = c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
+ newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd))
+ if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) {
+ t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq)
}
+ // Send another payload of half the size of rcvBufSize. This should fill up the
+ // socket receive buffer and we should see a zero window.
+ c.SendPacket(data[rcvBufSize/2:], &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqNum,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2))
+
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(0),
+ checker.TCPWindow(0),
),
)
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
// Receive data and check it.
- read := make([]byte, 0, 10)
+ read := make([]byte, 0, rcvBufSize)
for len(read) < len(data) {
v, _, err := c.EP.Read(nil)
if err != nil {
@@ -1986,15 +2076,15 @@ func TestNoWindowShrinking(t *testing.T) {
t.Fatalf("got data = %v, want = %v", read, data)
}
- // Check that we get an ACK for the newly non-zero window, which is the
- // new size.
+ // Check that we get an ACK for the newly non-zero window, which is the new
+ // receive buffer size we set after the connection was established.
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(seqNum)),
checker.TCPFlags(header.TCPFlagAck),
- checker.Window(5),
+ checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale),
),
)
}
@@ -2019,8 +2109,8 @@ func TestSimpleSend(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2061,8 +2151,8 @@ func TestZeroWindowSend(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2083,8 +2173,8 @@ func TestZeroWindowSend(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2123,16 +2213,16 @@ func TestScaledWindowConnect(t *testing.T) {
t.Fatalf("Write failed: %s", err)
}
- // Check that data is received, and that advertised window is 0xbfff,
+ // Check that data is received, and that advertised window is 0x5fff,
// that is, that it is scaled.
b := c.GetPacket()
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xbfff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0x5fff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2162,9 +2252,9 @@ func TestNonScaledWindowConnect(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xffff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0xffff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2198,7 +2288,8 @@ func TestScaledWindowAccept(t *testing.T) {
}
// Do 3-way handshake.
- c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2
+ c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -2228,16 +2319,16 @@ func TestScaledWindowAccept(t *testing.T) {
t.Fatalf("Write failed: %s", err)
}
- // Check that data is received, and that advertised window is 0xbfff,
+ // Check that data is received, and that advertised window is 0x5fff,
// that is, that it is scaled.
b := c.GetPacket()
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xbfff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0x5fff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2309,9 +2400,9 @@ func TestNonScaledWindowAccept(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xffff),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(0xffff),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2324,18 +2415,19 @@ func TestZeroScaledWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- // Set the window size such that a window scale of 4 will be used.
- const wnd = 65535 * 10
- const ws = uint32(4)
- c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
+ // Set the buffer size such that a window scale of 5 will be used.
+ const bufSz = 65535 * 10
+ const ws = uint32(5)
+ c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
// Write chunks of 50000 bytes.
- remain := wnd
+ remain := 0
sent := 0
data := make([]byte, 50000)
- for remain > len(data) {
+ // Keep writing till the window drops below len(data).
+ for {
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
@@ -2345,21 +2437,25 @@ func TestZeroScaledWindowReceive(t *testing.T) {
RcvWnd: 30000,
})
sent += len(data)
- remain -= len(data)
- checker.IPv4(t, c.GetPacket(),
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(remain>>ws)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
+ // Don't reduce window to zero here.
+ if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) {
+ remain = wnd << ws
+ break
+ }
}
// Make the window non-zero, but the scaled window zero.
- if remain >= 16 {
+ for remain >= 16 {
data = data[:remain-15]
c.SendPacket(data, &context.Headers{
SrcPort: context.TestPort,
@@ -2370,22 +2466,35 @@ func TestZeroScaledWindowReceive(t *testing.T) {
RcvWnd: 30000,
})
sent += len(data)
- remain -= len(data)
- checker.IPv4(t, c.GetPacket(),
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(0),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
+ // Since the receive buffer is split between window advertisement and
+ // application data buffer the window does not always reflect the space
+ // available and actual space available can be a bit more than what is
+ // advertised in the window.
+ wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize())
+ if wnd == 0 {
+ break
+ }
+ remain = wnd << ws
}
- // Read at least 1MSS of data. An ack should be sent in response to that.
+ // Read at least 2MSS of data. An ack should be sent in response to that.
+ // Since buffer space is now split in half between window and application
+ // data we need to read more than 1 MSS(65536) of data for a non-zero window
+ // update to be sent. For 1MSS worth of window to be available we need to
+ // read at least 128KB. Since our segments above were 50KB each it means
+ // we need to read at 3 packets.
sz := 0
- for sz < defaultMTU {
+ for sz < defaultMTU*2 {
v, _, err := c.EP.Read(nil)
if err != nil {
t.Fatalf("Read failed: %s", err)
@@ -2397,9 +2506,9 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(sz>>ws)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -2466,8 +2575,8 @@ func TestSegmentMerging(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize+1),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+uint32(i)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2489,8 +2598,8 @@ func TestSegmentMerging(t *testing.T) {
checker.PayloadLen(len(allData)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+11),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+11),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2537,8 +2646,8 @@ func TestDelay(t *testing.T) {
checker.PayloadLen(len(want)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2584,8 +2693,8 @@ func TestUndelay(t *testing.T) {
checker.PayloadLen(len(allData[0])+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2607,8 +2716,8 @@ func TestUndelay(t *testing.T) {
checker.PayloadLen(len(allData[1])+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2669,8 +2778,8 @@ func TestMSSNotDelayed(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(seq)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2721,8 +2830,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -2964,7 +3073,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- const wndScale = 2
+ const wndScale = 3
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
}
@@ -2999,7 +3108,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagSyn),
checker.SrcPort(tcpHdr.SourcePort()),
- checker.SeqNum(tcpHdr.SequenceNumber()),
+ checker.TCPSeqNum(tcpHdr.SequenceNumber()),
checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
),
)
@@ -3020,8 +3129,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
),
)
@@ -3314,8 +3423,8 @@ func TestFinImmediately(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3335,8 +3444,8 @@ func TestFinImmediately(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3357,8 +3466,8 @@ func TestFinRetransmit(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3368,8 +3477,8 @@ func TestFinRetransmit(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3389,8 +3498,8 @@ func TestFinRetransmit(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3413,8 +3522,8 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3438,8 +3547,8 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3460,8 +3569,8 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3488,8 +3597,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3507,8 +3616,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3527,8 +3636,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3548,8 +3657,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3572,8 +3681,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3597,8 +3706,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3613,8 +3722,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3634,8 +3743,8 @@ func TestFinWithPendingData(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3659,8 +3768,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3680,8 +3789,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3695,8 +3804,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3711,8 +3820,8 @@ func TestFinWithPartialAck(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(791),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -3803,8 +3912,8 @@ func scaledSendWindow(t *testing.T, scale uint8) {
checker.PayloadLen((1<<scale)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -3942,7 +4051,7 @@ func TestReceivedSegmentQueuing(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3993,8 +4102,8 @@ func TestReadAfterClosedState(t *testing.T) {
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
),
)
@@ -4018,8 +4127,8 @@ func TestReadAfterClosedState(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(uint32(791+len(data))),
+ checker.TCPSeqNum(uint32(c.IRS)+2),
+ checker.TCPAckNum(uint32(791+len(data))),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -4280,14 +4389,14 @@ func TestMinMaxBufferSizes(t *testing.T) {
}
}
- // Set values below the min.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
+ // Set values below the min/2.
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err)
}
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil {
t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err)
}
@@ -4298,13 +4407,15 @@ func TestMinMaxBufferSizes(t *testing.T) {
t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
}
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
+ // Values above max are capped at max and then doubled.
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2)
if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err)
}
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
+ // Values above max are capped at max and then doubled.
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2)
}
func TestBindToDeviceOption(t *testing.T) {
@@ -4646,8 +4757,8 @@ func TestPathMTUDiscovery(t *testing.T) {
checker.PayloadLen(size+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(seqNum),
- checker.AckNum(790),
+ checker.TCPSeqNum(seqNum),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -4898,8 +5009,8 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(uint32(790)),
+ checker.TCPSeqNum(uint32(c.IRS)),
+ checker.TCPAckNum(uint32(790)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -4932,8 +5043,8 @@ func TestKeepalive(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -4944,8 +5055,8 @@ func TestKeepalive(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
),
)
@@ -4970,8 +5081,8 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next-1)),
- checker.AckNum(uint32(790)),
+ checker.TCPSeqNum(uint32(next-1)),
+ checker.TCPAckNum(uint32(790)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -4997,8 +5108,8 @@ func TestKeepalive(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(0)),
+ checker.TCPSeqNum(uint32(next)),
+ checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -5038,7 +5149,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
checker.SrcPort(context.StackPort),
checker.DstPort(srcPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
if synCookieInUse {
@@ -5082,7 +5193,7 @@ func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCoo
checker.SrcPort(context.StackPort),
checker.DstPort(srcPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
if synCookieInUse {
@@ -5316,7 +5427,7 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
+ checker.TCPAckNum(uint32(irs)+1)))
})
}
}
@@ -5416,7 +5527,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
+ checker.TCPAckNum(uint32(irs)+1)))
})
}
}
@@ -5464,7 +5575,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
@@ -5642,7 +5753,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
+ checker.TCPAckNum(uint32(irs) + 1),
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
@@ -5663,8 +5774,8 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.AckNum(uint32(irs) + 1),
- checker.SeqNum(uint32(iss + 1)),
+ checker.TCPAckNum(uint32(irs) + 1),
+ checker.TCPSeqNum(uint32(iss + 1)),
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
@@ -5962,16 +6073,14 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
time.Sleep(latency)
rawEP.SendPacketWithTS([]byte{1}, tsVal)
- // Verify that the ACK has the expected window.
- wantRcvWnd := receiveBufferSize
- wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale))
- rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1))
+ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
+ rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize()
time.Sleep(25 * time.Millisecond)
// Allocate a large enough payload for the test.
- b := make([]byte, int(receiveBufferSize)*2)
- offset := 0
- payloadSize := receiveBufferSize - 1
+ payloadSize := receiveBufferSize * 2
+ b := make([]byte, int(payloadSize))
+
worker := (c.EP).(interface {
StopWork()
ResumeWork()
@@ -5980,11 +6089,15 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Stop the worker goroutine.
worker.StopWork()
- start := offset
- end := offset + payloadSize
+ start := 0
+ end := payloadSize / 2
packetsSent := 0
for ; start < end; start += mss {
- rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+ packetEnd := start + mss
+ if start+mss > end {
+ packetEnd = end
+ }
+ rawEP.SendPacketWithTS(b[start:packetEnd], tsVal)
packetsSent++
}
@@ -5992,29 +6105,20 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// are waiting to be read.
worker.ResumeWork()
- // Since we read no bytes the window should goto zero till the
- // application reads some of the data.
- // Discard all intermediate acks except the last one.
- if packetsSent > 100 {
- for i := 0; i < (packetsSent / 100); i++ {
- _ = c.GetPacket()
- }
+ // Since we sent almost the full receive buffer worth of data (some may have
+ // been dropped due to segment overheads), we should get a zero window back.
+ pkt = c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(pkt).Payload())
+ gotRcvWnd := tcpHdr.WindowSize()
+ wantAckNum := tcpHdr.AckNumber()
+ if got, want := int(gotRcvWnd), 0; got != want {
+ t.Fatalf("got rcvWnd: %d, want: %d", got, want)
}
- rawEP.VerifyACKRcvWnd(0)
time.Sleep(25 * time.Millisecond)
- // Verify that sending more data when window is closed is dropped and
- // not acked.
+ // Verify that sending more data when receiveBuffer is exhausted.
rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
- // Verify that the stack sends us back an ACK with the sequence number
- // of the last packet sent indicating it was dropped.
- p := c.GetPacket()
- checker.IPv4(t, p, checker.TCP(
- checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
- checker.Window(0),
- ))
-
// Now read all the data from the endpoint and verify that advertised
// window increases to the full available buffer size.
for {
@@ -6027,23 +6131,26 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Verify that we receive a non-zero window update ACK. When running
// under thread santizer this test can end up sending more than 1
// ack, 1 for the non-zero window
- p = c.GetPacket()
+ p := c.GetPacket()
checker.IPv4(t, p, checker.TCP(
- checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
+ checker.TCPAckNum(uint32(wantAckNum)),
func(t *testing.T, h header.Transport) {
tcp, ok := h.(header.TCP)
if !ok {
return
}
- if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) {
- t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w)
+ // We use 10% here as the error margin upwards as the initial window we
+ // got was afer 1 segment was already in the receive buffer queue.
+ tolerance := 1.1
+ if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) {
+ t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance))
}
},
))
}
-// This test verifies that the auto tuning does not grow the receive buffer if
-// the application is not reading the data actively.
+// This test verifies that the advertised window is auto-tuned up as the
+// application is reading the data that is being received.
func TestReceiveBufferAutoTuning(t *testing.T) {
const mtu = 1500
const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
@@ -6053,9 +6160,6 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// Enable Auto-tuning.
stk := c.Stack()
- // Set lower limits for auto-tuning tests. This is required because the
- // test stops the worker which can cause packets to be dropped because
- // the segment queue holding unprocessed packets is limited to 300.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
{
@@ -6077,8 +6181,10 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
-
- wantRcvWnd := receiveBufferSize
+ tsVal := uint32(rawEP.TSVal)
+ rawEP.SendPacketWithTS([]byte{1}, tsVal)
+ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
+ curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
scaleRcvWnd := func(rcvWnd int) uint16 {
return uint16(rcvWnd >> uint16(c.WindowScale))
}
@@ -6095,14 +6201,8 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
StopWork()
ResumeWork()
})
- tsVal := rawEP.TSVal
- // We are going to do our own computation of what the moderated receive
- // buffer should be based on sent/copied data per RTT and verify that
- // the advertised window by the stack matches our calculations.
- prevCopied := 0
- done := false
latency := 1 * time.Millisecond
- for i := 0; !done; i++ {
+ for i := 0; i < 5; i++ {
tsVal++
// Stop the worker goroutine.
@@ -6124,15 +6224,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// Give 1ms for the worker to process the packets.
time.Sleep(1 * time.Millisecond)
- // Verify that the advertised window on the ACK is reduced by
- // the total bytes sent.
- expectedWnd := wantRcvWnd - totalSent
- if packetsSent > 100 {
- for i := 0; i < (packetsSent / 100); i++ {
- _ = c.GetPacket()
+ lastACK := c.GetPacket()
+ // Discard any intermediate ACKs and only check the last ACK we get in a
+ // short time period of few ms.
+ for {
+ time.Sleep(1 * time.Millisecond)
+ pkt := c.GetPacketNonBlocking()
+ if pkt == nil {
+ break
}
+ lastACK = pkt
+ }
+ if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want {
+ t.Fatalf("advertised window got: %d, want <= %d", got, want)
}
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd))
// Now read all the data from the endpoint and invoke the
// moderation API to allow for receive buffer auto-tuning
@@ -6157,35 +6262,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
rawEP.NextSeqNum++
-
if i == 0 {
// In the first iteration the receiver based RTT is not
// yet known as a result the moderation code should not
// increase the advertised window.
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd))
- prevCopied = totalCopied
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd))
} else {
- rttCopied := totalCopied
- if i == 1 {
- // The moderation code accumulates copied bytes till
- // RTT is established. So add in the bytes sent in
- // the first iteration to the total bytes for this
- // RTT.
- rttCopied += prevCopied
- // Now reset it to the initial value used by the
- // auto tuning logic.
- prevCopied = tcp.InitialCwnd * mss * 2
- }
- newWnd := rttCopied<<1 + 16*mss
- grow := (newWnd * (rttCopied - prevCopied)) / prevCopied
- newWnd += (grow << 1)
- if newWnd > maxReceiveBufferSize {
- newWnd = maxReceiveBufferSize
- done = true
+ pkt := c.GetPacket()
+ curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
+ // If thew new current window is close maxReceiveBufferSize then terminate
+ // the loop. This can happen before all iterations are done due to timing
+ // differences when running the test.
+ if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 {
+ break
}
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd))
- wantRcvWnd = newWnd
- prevCopied = rttCopied
// Increase the latency after first two iterations to
// establish a low RTT value in the receiver since it
// only tracks the lowest value. This ensures that when
@@ -6198,6 +6288,12 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
offset += payloadSize
payloadSize *= 2
}
+ // Check that at the end of our iterations the receive window grew close to the maximum
+ // permissible size of maxReceiveBufferSize/2
+ if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want {
+ t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want)
+ }
+
}
func TestDelayEnabled(t *testing.T) {
@@ -6349,8 +6445,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6367,8 +6463,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Now send a RST and this should be ignored and not
@@ -6396,8 +6492,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
}
@@ -6468,8 +6564,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6486,8 +6582,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Out of order ACK should generate an immediate ACK in
@@ -6503,8 +6599,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
}
@@ -6575,8 +6671,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6593,8 +6689,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Send a SYN request w/ sequence number lower than
@@ -6732,8 +6828,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
finHeaders := &context.Headers{
@@ -6750,8 +6846,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
time.Sleep(2 * time.Second)
@@ -6765,8 +6861,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+2)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Sleep for 4 seconds so at this point we are 1 second past the
@@ -6794,8 +6890,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(ackHeaders.AckNum)),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst)))
if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
@@ -6894,8 +6990,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+2),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(iss)+2),
checker.TCPFlags(header.TCPFlagAck)))
// Now write a few bytes and then close the endpoint.
@@ -6913,8 +7009,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -6928,8 +7024,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
- checker.AckNum(uint32(iss+2)),
+ checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))),
+ checker.TCPAckNum(uint32(iss+2)),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
// First send a partial ACK.
@@ -6974,8 +7070,8 @@ func TestTCPCloseWithData(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(ackHeaders.AckNum)),
- checker.AckNum(0),
+ checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
+ checker.TCPAckNum(0),
checker.TCPFlags(header.TCPFlagRst)))
}
@@ -7011,8 +7107,8 @@ func TestTCPUserTimeout(t *testing.T) {
checker.PayloadLen(len(view)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
+ checker.TCPSeqNum(next),
+ checker.TCPAckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -7046,8 +7142,8 @@ func TestTCPUserTimeout(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(0)),
+ checker.TCPSeqNum(uint32(next)),
+ checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -7108,8 +7204,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(uint32(790)),
+ checker.TCPSeqNum(uint32(c.IRS)),
+ checker.TCPAckNum(uint32(790)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -7134,8 +7230,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(0)),
+ checker.TCPSeqNum(uint32(c.IRS+1)),
+ checker.TCPAckNum(uint32(0)),
checker.TCPFlags(header.TCPFlagRst),
),
)
@@ -7151,9 +7247,9 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
}
}
-func TestIncreaseWindowOnReceive(t *testing.T) {
+func TestIncreaseWindowOnRead(t *testing.T) {
// This test ensures that the endpoint sends an ack,
- // after recv() when the window grows to more than 1 MSS.
+ // after read() when the window grows by more than 1 MSS.
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -7162,10 +7258,9 @@ func TestIncreaseWindowOnReceive(t *testing.T) {
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
- remain := rcvBuf
+ remain := rcvBuf * 2
sent := 0
data := make([]byte, defaultMTU/2)
- lastWnd := uint16(0)
for remain > len(data) {
c.SendPacket(data, &context.Headers{
@@ -7178,46 +7273,43 @@ func TestIncreaseWindowOnReceive(t *testing.T) {
})
sent += len(data)
remain -= len(data)
-
- lastWnd = uint16(remain)
- if remain > 0xffff {
- lastWnd = 0xffff
- }
- checker.IPv4(t, c.GetPacket(),
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(lastWnd),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
checker.TCPFlags(header.TCPFlagAck),
),
)
+ // Break once the window drops below defaultMTU/2
+ if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 {
+ break
+ }
}
- if lastWnd == 0xffff || lastWnd == 0 {
- t.Fatalf("expected small, non-zero window: %d", lastWnd)
- }
-
- // We now have < 1 MSS in the buffer space. Read the data! An
- // ack should be sent in response to that. The window was not
- // zero, but it grew to larger than MSS.
- if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
- if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %s", err)
+ // We now have < 1 MSS in the buffer space. Read at least > 2 MSS
+ // worth of data as receive buffer space
+ read := 0
+ // defaultMTU is a good enough estimate for the MSS used for this
+ // connection.
+ for read < defaultMTU*2 {
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+ read += len(v)
}
- // After reading two packets, we surely crossed MSS. See the ack:
+ // After reading > MSS worth of data, we surely crossed MSS. See the ack:
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(0xffff)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindow(uint16(0xffff)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -7234,10 +7326,9 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
- remain := rcvBuf
+ remain := rcvBuf * 2
sent := 0
data := make([]byte, defaultMTU/2)
- lastWnd := uint16(0)
for remain > len(data) {
c.SendPacket(data, &context.Headers{
@@ -7251,38 +7342,29 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
sent += len(data)
remain -= len(data)
- lastWnd = uint16(remain)
- if remain > 0xffff {
- lastWnd = 0xffff
- }
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(lastWnd),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindowLessThanEq(0xffff),
checker.TCPFlags(header.TCPFlagAck),
),
)
}
- if lastWnd == 0xffff || lastWnd == 0 {
- t.Fatalf("expected small, non-zero window: %d", lastWnd)
- }
-
// Increasing the buffer from should generate an ACK,
// since window grew from small value to larger equal MSS
c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
- // After reading two packets, we surely crossed MSS. See the ack:
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(0xffff)),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(790+sent)),
+ checker.TCPWindow(uint16(0xffff)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -7327,8 +7409,8 @@ func TestTCPDeferAccept(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
// Give a bit of time for the socket to be delivered to the accept queue.
time.Sleep(50 * time.Millisecond)
@@ -7342,8 +7424,8 @@ func TestTCPDeferAccept(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
func TestTCPDeferAcceptTimeout(t *testing.T) {
@@ -7380,7 +7462,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
+ checker.TCPAckNum(uint32(irs)+1)))
// Send data. This should result in an acceptable endpoint.
c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
@@ -7396,8 +7478,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
// Give sometime for the endpoint to be delivered to the accept queue.
time.Sleep(50 * time.Millisecond)
@@ -7412,8 +7494,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
+ checker.TCPSeqNum(uint32(iss+1)),
+ checker.TCPAckNum(uint32(irs+5))))
}
func TestResetDuringClose(t *testing.T) {
@@ -7438,8 +7520,8 @@ func TestResetDuringClose(t *testing.T) {
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(irs.Add(1))),
- checker.AckNum(uint32(iss.Add(5)))))
+ checker.TCPSeqNum(uint32(irs.Add(1))),
+ checker.TCPAckNum(uint32(iss.Add(5)))))
// Close in a separate goroutine so that we can trigger
// a race with the RST we send below. This should not
@@ -7520,3 +7602,14 @@ func TestSetStackTimeWaitReuse(t *testing.T) {
}
}
}
+
+// generateRandomPayload generates a random byte slice of the specified length
+// causing a fatal test failure if it is unable to do so.
+func generateRandomPayload(t *testing.T, n int) []byte {
+ t.Helper()
+ buf := make([]byte, n)
+ if _, err := rand.Read(buf); err != nil {
+ t.Fatalf("rand.Read(buf) failed: %s", err)
+ }
+ return buf
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 44593ed98..0f9ed06cd 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -159,9 +159,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
checker.PayloadLen(len(data)+header.TCPMinimumSize+12),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(wndSize),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(wndSize),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
checker.TCPTimestampChecker(true, 0, tsVal+1),
),
@@ -181,7 +181,8 @@ func TestTimeStampEnabledAccept(t *testing.T) {
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
+ // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that.
+ {false, 5, 0x4000},
}
for _, tc := range testCases {
timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
@@ -219,9 +220,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(wndSize),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(790),
+ checker.TCPWindow(wndSize),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
checker.TCPTimestampChecker(false, 0, 0),
),
@@ -237,7 +238,9 @@ func TestTimeStampDisabledAccept(t *testing.T) {
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
+ // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of
+ // that.
+ {false, 5, 0x4000},
}
for _, tc := range testCases {
timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 85e8c1c75..ebbae6e2f 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -145,6 +145,10 @@ type Context struct {
// WindowScale is the expected window scale in SYN packets sent by
// the stack.
WindowScale uint8
+
+ // RcvdWindowScale is the actual window scale sent by the stack in
+ // SYN/SYN-ACK.
+ RcvdWindowScale uint8
}
// New allocates and initializes a test context containing a new
@@ -261,18 +265,17 @@ func (c *Context) CheckNoPacket(errMsg string) {
c.CheckNoPacketTimeout(errMsg, 1*time.Second)
}
-// GetPacket reads a packet from the link layer endpoint and verifies
+// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies
// that it is an IPv4 packet with the expected source and destination
-// addresses. It will fail with an error if no packet is received for
-// 2 seconds.
-func (c *Context) GetPacket() []byte {
+// addresses. If no packet is received in the specified timeout it will return
+// nil.
+func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte {
c.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
p, ok := c.linkEP.ReadContext(ctx)
if !ok {
- c.t.Fatalf("Packet wasn't written out")
return nil
}
@@ -280,6 +283,14 @@ func (c *Context) GetPacket() []byte {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
+ // Just check that the stack set the transport protocol number for outbound
+ // TCP messages.
+ // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part
+ // of the headerinfo.
+ if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
+ c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
+ }
+
vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
b := vv.ToView()
@@ -291,6 +302,21 @@ func (c *Context) GetPacket() []byte {
return b
}
+// GetPacket reads a packet from the link layer endpoint and verifies
+// that it is an IPv4 packet with the expected source and destination
+// addresses.
+func (c *Context) GetPacket() []byte {
+ c.t.Helper()
+
+ p := c.GetPacketWithTimeout(5 * time.Second)
+ if p == nil {
+ c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
+
+ return p
+}
+
// GetPacketNonBlocking reads a packet from the link layer endpoint
// and verifies that it is an IPv4 packet with the expected source
// and destination address. If no packet is available it will return
@@ -307,6 +333,14 @@ func (c *Context) GetPacketNonBlocking() []byte {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
+ // Just check that the stack set the transport protocol number for outbound
+ // TCP messages.
+ // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part
+ // of the headerinfo.
+ if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
+ c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
+ }
+
vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
b := vv.ToView()
@@ -470,8 +504,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op
checker.PayloadLen(size+header.TCPMinimumSize+optlen),
checker.TCP(
checker.DstPort(TestPort),
- checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+ checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -497,8 +531,8 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int
checker.PayloadLen(size+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(TestPort),
- checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+ checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -636,6 +670,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
}
tcpHdr := header.TCP(header.IPv4(b).Payload())
+ synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */)
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
c.SendPacket(nil, &Headers{
@@ -653,8 +688,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
checker.TCP(
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
+ checker.TCPSeqNum(uint32(c.IRS)+1),
+ checker.TCPAckNum(uint32(iss)+1),
),
)
@@ -671,6 +706,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
}
+ c.RcvdWindowScale = uint8(synOpts.WS)
c.Port = tcpHdr.SourcePort()
}
@@ -742,17 +778,18 @@ func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
}
-// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
-// tsVal.
-func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
+// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches
+// the provided tsVal as well as returns the original packet.
+func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte {
+ r.C.t.Helper()
// Read ACK and verify that tsEcr of ACK packet is [1,2,3,4]
ackPacket := r.C.GetPacket()
checker.IPv4(r.C.t, ackPacket,
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
+ checker.TCPSeqNum(uint32(r.AckNum)),
+ checker.TCPAckNum(uint32(r.NextSeqNum)),
checker.TCPTimestampChecker(true, 0, tsVal),
),
)
@@ -760,19 +797,28 @@ func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
opts := tcpSeg.ParsedOptions()
r.RecentTS = opts.TSVal
+ return ackPacket
+}
+
+// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
+// tsVal.
+func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
+ r.C.t.Helper()
+ _ = r.VerifyAndReturnACKWithTS(tsVal)
}
// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK
// matches the provided rcvWnd.
func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) {
+ r.C.t.Helper()
ackPacket := r.C.GetPacket()
checker.IPv4(r.C.t, ackPacket,
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
- checker.Window(rcvWnd),
+ checker.TCPSeqNum(uint32(r.AckNum)),
+ checker.TCPAckNum(uint32(r.NextSeqNum)),
+ checker.TCPWindow(rcvWnd),
),
)
}
@@ -791,8 +837,8 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
+ checker.TCPSeqNum(uint32(r.AckNum)),
+ checker.TCPAckNum(uint32(r.NextSeqNum)),
checker.TCPSACKBlockChecker(sackBlocks),
),
)
@@ -884,8 +930,8 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
tcpCheckers := []checker.TransportChecker{
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS) + 1),
- checker.AckNum(uint32(iss) + 1),
+ checker.TCPSeqNum(uint32(c.IRS) + 1),
+ checker.TCPAckNum(uint32(iss) + 1),
}
// Verify that tsEcr of ACK packet is wantOptions.TSVal if the
@@ -920,7 +966,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Mark in context that timestamp option is enabled for this endpoint.
c.TimeStampEnabled = true
-
+ c.RcvdWindowScale = uint8(synOptions.WS)
return &RawEndpoint{
C: c,
SrcPort: tcpSeg.DestinationPort(),
@@ -1013,6 +1059,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP
// value of the window scaling option to be sent in the SYN. If synOptions.WS >
// 0 then we send the WindowScale option.
func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+ c.t.Helper()
opts := make([]byte, header.TCPOptionsMaximumSize)
offset := 0
offset += header.EncodeMSSOption(uint32(maxPayload), opts)
@@ -1051,13 +1098,14 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
// are present.
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
+ rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */)
c.IRS = seqnum.Value(tcp.SequenceNumber())
tcpCheckers := []checker.TransportChecker{
checker.SrcPort(StackPort),
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(iss) + 1),
+ checker.TCPAckNum(uint32(iss) + 1),
checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
}
@@ -1100,6 +1148,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
// Send ACK.
c.SendPacket(nil, ackHeaders)
+ c.RcvdWindowScale = uint8(rcvdSynOptions.WS)
c.Port = StackPort
return &RawEndpoint{
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 518f636f0..086d0bdbc 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -996,6 +996,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// Initialize the UDP header.
udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ pkt.TransportProtocolNumber = ProtocolNumber
length := uint16(pkt.Size())
udp.Encode(&header.UDPFields{
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 7d6b91a75..a1d0f49d9 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -80,126 +80,21 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
return h.SourcePort(), h.DestinationPort(), nil
}
-// HandleUnknownDestinationPacket handles packets targeted at this protocol but
-// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+// HandleUnknownDestinationPacket handles packets that are targeted at this
+// protocol but don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
- // Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
- return true
+ return stack.UnknownDestinationPacketMalformed
}
if !verifyChecksum(r, hdr, pkt) {
- // Checksum Error.
r.Stack().Stats().UDP.ChecksumErrors.Increment()
- return true
+ return stack.UnknownDestinationPacketMalformed
}
- // Only send ICMP error if the address is not a multicast/broadcast
- // v4/v6 address or the source is not the unspecified address.
- //
- // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any {
- return true
- }
-
- // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
- // Unreachable messages with code:
- //
- // 2 (Protocol Unreachable), when the designated transport protocol
- // is not supported; or
- //
- // 3 (Port Unreachable), when the designated transport protocol
- // (e.g., UDP) is unable to demultiplex the datagram but has no
- // protocol mechanism to inform the sender.
- switch len(id.LocalAddress) {
- case header.IPv4AddressSize:
- if !r.Stack().AllowICMPMessage() {
- r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment()
- return true
- }
- // As per RFC 1812 Section 4.3.2.3
- //
- // ICMP datagram SHOULD contain as much of the original
- // datagram as possible without the length of the ICMP
- // datagram exceeding 576 bytes
- //
- // NOTE: The above RFC referenced is different from the original
- // recommendation in RFC 1122 where it mentioned that at least 8
- // bytes of the payload must be included. Today linux and other
- // systems implement the] RFC1812 definition and not the original
- // RFC 1122 requirement.
- mtu := int(r.MTU())
- if mtu > header.IPv4MinimumProcessableDatagramSize {
- mtu = header.IPv4MinimumProcessableDatagramSize
- }
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
- available := int(mtu) - headerLen
- payloadLen := pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size() + pkt.Data.Size()
- if payloadLen > available {
- payloadLen = available
- }
-
- // The buffers used by pkt may be used elsewhere in the system.
- // For example, a raw or packet socket may use what UDP
- // considers an unreachable destination. Thus we deep copy pkt
- // to prevent multiple ownership and SR errors.
- newHeader := append(buffer.View(nil), pkt.NetworkHeader().View()...)
- newHeader = append(newHeader, pkt.TransportHeader().View()...)
- payload := newHeader.ToVectorisedView()
- payload.AppendView(pkt.Data.ToView())
- payload.CapLength(payloadLen)
-
- icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: headerLen,
- Data: payload,
- })
- icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt)
-
- case header.IPv6AddressSize:
- if !r.Stack().AllowICMPMessage() {
- r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment()
- return true
- }
-
- // As per RFC 4443 section 2.4
- //
- // (c) Every ICMPv6 error message (type < 128) MUST include
- // as much of the IPv6 offending (invoking) packet (the
- // packet that caused the error) as possible without making
- // the error message packet exceed the minimum IPv6 MTU
- // [IPv6].
- mtu := int(r.MTU())
- if mtu > header.IPv6MinimumMTU {
- mtu = header.IPv6MinimumMTU
- }
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
- available := int(mtu) - headerLen
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
- payloadLen := len(network) + len(transport) + pkt.Data.Size()
- if payloadLen > available {
- payloadLen = available
- }
- payload := buffer.NewVectorisedView(len(network)+len(transport), []buffer.View{network, transport})
- payload.Append(pkt.Data)
- payload.CapLength(payloadLen)
-
- icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: headerLen,
- Data: payload,
- })
- icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, icmpPkt.Data))
- r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt)
- }
- return true
+ return stack.UnknownDestinationPacketUnhandled
}
// SetOption implements stack.TransportProtocol.SetOption.
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index d5881d183..64a5fc696 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -388,6 +388,10 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
}
+ if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want {
+ c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want)
+ }
+
vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
b := vv.ToView()