diff options
Diffstat (limited to 'pkg/tcpip')
52 files changed, 1844 insertions, 1144 deletions
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index 563bc78ea..c326fab54 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -14,6 +14,8 @@ go_library( go_test( name = "buffer_test", size = "small", - srcs = ["view_test.go"], + srcs = [ + "view_test.go", + ], library = ":buffer", ) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index b769094dc..19627fa9b 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -339,7 +339,7 @@ func NoChecksum(noChecksum bool) TransportChecker { udp, ok := h.(header.UDP) if !ok { - return + t.Fatalf("UDP header not found in h: %T", h) } if b := udp.Checksum() == 0; b != noChecksum { @@ -348,14 +348,14 @@ func NoChecksum(noChecksum bool) TransportChecker { } } -// SeqNum creates a checker that checks the sequence number. -func SeqNum(seq uint32) TransportChecker { +// TCPSeqNum creates a checker that checks the sequence number. +func TCPSeqNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in h: %T", h) } if s := tcp.SequenceNumber(); s != seq { @@ -364,14 +364,14 @@ func SeqNum(seq uint32) TransportChecker { } } -// AckNum creates a checker that checks the ack number. -func AckNum(seq uint32) TransportChecker { +// TCPAckNum creates a checker that checks the ack number. +func TCPAckNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in h: %T", h) } if s := tcp.AckNumber(); s != seq { @@ -380,18 +380,52 @@ func AckNum(seq uint32) TransportChecker { } } -// Window creates a checker that checks the tcp window. -func Window(window uint16) TransportChecker { +// TCPWindow creates a checker that checks the tcp window. +func TCPWindow(window uint16) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in hdr : %T", h) } if w := tcp.WindowSize(); w != window { - t.Errorf("Bad window, got 0x%x, want 0x%x", w, window) + t.Errorf("Bad window, got %d, want %d", w, window) + } + } +} + +// TCPWindowGreaterThanEq creates a checker that checks that the TCP window +// is greater than or equal to the provided value. +func TCPWindowGreaterThanEq(window uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + tcp, ok := h.(header.TCP) + if !ok { + t.Fatalf("TCP header not found in h: %T", h) + } + + if w := tcp.WindowSize(); w < window { + t.Errorf("Bad window, got %d, want > %d", w, window) + } + } +} + +// TCPWindowLessThanEq creates a checker that checks that the tcp window +// is less than or equal to the provided value. +func TCPWindowLessThanEq(window uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + tcp, ok := h.(header.TCP) + if !ok { + t.Fatalf("TCP header not found in h: %T", h) + } + + if w := tcp.WindowSize(); w > window { + t.Errorf("Bad window, got %d, want < %d", w, window) } } } @@ -403,7 +437,7 @@ func TCPFlags(flags uint8) TransportChecker { tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in h: %T", h) } if f := tcp.Flags(); f != flags { @@ -420,7 +454,7 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker { tcp, ok := h.(header.TCP) if !ok { - return + t.Fatalf("TCP header not found in h: %T", h) } if f := tcp.Flags(); (f & mask) != (flags & mask) { diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD new file mode 100644 index 000000000..114d43df3 --- /dev/null +++ b/pkg/tcpip/faketime/BUILD @@ -0,0 +1,24 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "faketime", + srcs = ["faketime.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "@com_github_dpjacques_clockwork//:go_default_library", + ], +) + +go_test( + name = "faketime_test", + size = "small", + srcs = [ + "faketime_test.go", + ], + deps = [ + "//pkg/tcpip/faketime", + ], +) diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/faketime/faketime.go index 92c8cb534..1193f1d7d 100644 --- a/pkg/tcpip/stack/fake_time_test.go +++ b/pkg/tcpip/faketime/faketime.go @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -package stack +// Package faketime provides a fake clock that implements tcpip.Clock interface. +package faketime import ( "container/heap" @@ -23,7 +24,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -type fakeClock struct { +// ManualClock implements tcpip.Clock and only advances manually with Advance +// method. +type ManualClock struct { clock clockwork.FakeClock // mu protects the fields below. @@ -39,34 +42,35 @@ type fakeClock struct { waitGroups map[time.Time]*sync.WaitGroup } -func newFakeClock() *fakeClock { - return &fakeClock{ +// NewManualClock creates a new ManualClock instance. +func NewManualClock() *ManualClock { + return &ManualClock{ clock: clockwork.NewFakeClock(), times: &timeHeap{}, waitGroups: make(map[time.Time]*sync.WaitGroup), } } -var _ tcpip.Clock = (*fakeClock)(nil) +var _ tcpip.Clock = (*ManualClock)(nil) // NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (fc *fakeClock) NowNanoseconds() int64 { - return fc.clock.Now().UnixNano() +func (mc *ManualClock) NowNanoseconds() int64 { + return mc.clock.Now().UnixNano() } // NowMonotonic implements tcpip.Clock.NowMonotonic. -func (fc *fakeClock) NowMonotonic() int64 { - return fc.NowNanoseconds() +func (mc *ManualClock) NowMonotonic() int64 { + return mc.NowNanoseconds() } // AfterFunc implements tcpip.Clock.AfterFunc. -func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { - until := fc.clock.Now().Add(d) - wg := fc.addWait(until) - return &fakeTimer{ - clock: fc, +func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { + until := mc.clock.Now().Add(d) + wg := mc.addWait(until) + return &manualTimer{ + clock: mc, until: until, - timer: fc.clock.AfterFunc(d, func() { + timer: mc.clock.AfterFunc(d, func() { defer wg.Done() f() }), @@ -75,110 +79,113 @@ func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { // addWait adds an additional wait to the WaitGroup for parallel execution of // all work scheduled for t. Returns a reference to the WaitGroup modified. -func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup { - fc.mu.RLock() - wg, ok := fc.waitGroups[t] - fc.mu.RUnlock() +func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup { + mc.mu.RLock() + wg, ok := mc.waitGroups[t] + mc.mu.RUnlock() if ok { wg.Add(1) return wg } - fc.mu.Lock() - heap.Push(fc.times, t) - fc.mu.Unlock() + mc.mu.Lock() + heap.Push(mc.times, t) + mc.mu.Unlock() wg = &sync.WaitGroup{} wg.Add(1) - fc.mu.Lock() - fc.waitGroups[t] = wg - fc.mu.Unlock() + mc.mu.Lock() + mc.waitGroups[t] = wg + mc.mu.Unlock() return wg } // removeWait removes a wait from the WaitGroup for parallel execution of all // work scheduled for t. -func (fc *fakeClock) removeWait(t time.Time) { - fc.mu.RLock() - defer fc.mu.RUnlock() +func (mc *ManualClock) removeWait(t time.Time) { + mc.mu.RLock() + defer mc.mu.RUnlock() - wg := fc.waitGroups[t] + wg := mc.waitGroups[t] wg.Done() } -// advance executes all work that have been scheduled to execute within d from -// the current fake time. Blocks until all work has completed execution. -func (fc *fakeClock) advance(d time.Duration) { +// Advance executes all work that have been scheduled to execute within d from +// the current time. Blocks until all work has completed execution. +func (mc *ManualClock) Advance(d time.Duration) { // Block until all the work is done - until := fc.clock.Now().Add(d) + until := mc.clock.Now().Add(d) for { - fc.mu.Lock() - if fc.times.Len() == 0 { - fc.mu.Unlock() - return + mc.mu.Lock() + if mc.times.Len() == 0 { + mc.mu.Unlock() + break } - t := heap.Pop(fc.times).(time.Time) + t := heap.Pop(mc.times).(time.Time) if t.After(until) { // No work to do - heap.Push(fc.times, t) - fc.mu.Unlock() - return + heap.Push(mc.times, t) + mc.mu.Unlock() + break } - fc.mu.Unlock() + mc.mu.Unlock() - diff := t.Sub(fc.clock.Now()) - fc.clock.Advance(diff) + diff := t.Sub(mc.clock.Now()) + mc.clock.Advance(diff) - fc.mu.RLock() - wg := fc.waitGroups[t] - fc.mu.RUnlock() + mc.mu.RLock() + wg := mc.waitGroups[t] + mc.mu.RUnlock() wg.Wait() - fc.mu.Lock() - delete(fc.waitGroups, t) - fc.mu.Unlock() + mc.mu.Lock() + delete(mc.waitGroups, t) + mc.mu.Unlock() + } + if now := mc.clock.Now(); until.After(now) { + mc.clock.Advance(until.Sub(now)) } } -type fakeTimer struct { - clock *fakeClock +type manualTimer struct { + clock *ManualClock timer clockwork.Timer mu sync.RWMutex until time.Time } -var _ tcpip.Timer = (*fakeTimer)(nil) +var _ tcpip.Timer = (*manualTimer)(nil) // Reset implements tcpip.Timer.Reset. -func (ft *fakeTimer) Reset(d time.Duration) { - if !ft.timer.Reset(d) { +func (t *manualTimer) Reset(d time.Duration) { + if !t.timer.Reset(d) { return } - ft.mu.Lock() - defer ft.mu.Unlock() + t.mu.Lock() + defer t.mu.Unlock() - ft.clock.removeWait(ft.until) - ft.until = ft.clock.clock.Now().Add(d) - ft.clock.addWait(ft.until) + t.clock.removeWait(t.until) + t.until = t.clock.clock.Now().Add(d) + t.clock.addWait(t.until) } // Stop implements tcpip.Timer.Stop. -func (ft *fakeTimer) Stop() bool { - if !ft.timer.Stop() { +func (t *manualTimer) Stop() bool { + if !t.timer.Stop() { return false } - ft.mu.RLock() - defer ft.mu.RUnlock() + t.mu.RLock() + defer t.mu.RUnlock() - ft.clock.removeWait(ft.until) + t.clock.removeWait(t.until) return true } diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go new file mode 100644 index 000000000..c2704df2c --- /dev/null +++ b/pkg/tcpip/faketime/faketime_test.go @@ -0,0 +1,95 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package faketime_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip/faketime" +) + +func TestManualClockAdvance(t *testing.T) { + const timeout = time.Millisecond + clock := faketime.NewManualClock() + start := clock.NowMonotonic() + clock.Advance(timeout) + if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want { + t.Errorf("got = %d, want = %d", got, want) + } +} + +func TestManualClockAfterFunc(t *testing.T) { + const ( + timeout1 = time.Millisecond // timeout for counter1 + timeout2 = 2 * time.Millisecond // timeout for counter2 + ) + tests := []struct { + name string + advance time.Duration + wantCounter1 int + wantCounter2 int + }{ + { + name: "before timeout1", + advance: timeout1 - 1, + wantCounter1: 0, + wantCounter2: 0, + }, + { + name: "timeout1", + advance: timeout1, + wantCounter1: 1, + wantCounter2: 0, + }, + { + name: "timeout2", + advance: timeout2, + wantCounter1: 1, + wantCounter2: 1, + }, + { + name: "after timeout2", + advance: timeout2 + 1, + wantCounter1: 1, + wantCounter2: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + counter1 := 0 + counter2 := 0 + clock.AfterFunc(timeout1, func() { + counter1++ + }) + clock.AfterFunc(timeout2, func() { + counter2++ + }) + start := clock.NowMonotonic() + clock.Advance(test.advance) + if got, want := counter1, test.wantCounter1; got != want { + t.Errorf("got counter1 = %d, want = %d", got, want) + } + if got, want := counter2, test.wantCounter2; got != want { + t.Errorf("got counter2 = %d, want = %d", got, want) + } + if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want { + t.Errorf("got elapsed = %d, want = %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index be03fb086..c00bcadfb 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -31,6 +31,27 @@ const ( // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. ICMPv4MinimumSize = 8 + // ICMPv4MinimumErrorPayloadSize Is the smallest number of bytes of an + // errant packet's transport layer that an ICMP error type packet should + // attempt to send as per RFC 792 (see each type) and RFC 1122 + // section 3.2.2 which states: + // Every ICMP error message includes the Internet header and at + // least the first 8 data octets of the datagram that triggered + // the error; more than 8 octets MAY be sent; this header and data + // MUST be unchanged from the received datagram. + // + // RFC 792 shows: + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Type | Code | Checksum | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | unused | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Internet Header + 64 bits of Original Data Datagram | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + ICMPv4MinimumErrorPayloadSize = 8 + // ICMPv4ProtocolNumber is the ICMP transport protocol number. ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 @@ -39,15 +60,19 @@ const ( icmpv4ChecksumOffset = 2 // icmpv4MTUOffset is the offset of the MTU field - // in a ICMPv4FragmentationNeeded message. + // in an ICMPv4FragmentationNeeded message. icmpv4MTUOffset = 6 // icmpv4IdentOffset is the offset of the ident field - // in a ICMPv4EchoRequest/Reply message. + // in an ICMPv4EchoRequest/Reply message. icmpv4IdentOffset = 4 + // icmpv4PointerOffset is the offset of the pointer field + // in an ICMPv4ParamProblem message. + icmpv4PointerOffset = 4 + // icmpv4SequenceOffset is the offset of the sequence field - // in a ICMPv4EchoRequest/Reply message. + // in an ICMPv4EchoRequest/Reply message. icmpv4SequenceOffset = 6 ) @@ -72,15 +97,23 @@ const ( ICMPv4InfoReply ICMPv4Type = 16 ) +// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792. +const ( + ICMPv4TTLExceeded ICMPv4Code = 0 +) + // ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792. const ( - ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4NetUnreachable ICMPv4Code = 0 ICMPv4HostUnreachable ICMPv4Code = 1 ICMPv4ProtoUnreachable ICMPv4Code = 2 ICMPv4PortUnreachable ICMPv4Code = 3 ICMPv4FragmentationNeeded ICMPv4Code = 4 ) +// ICMPv4UnusedCode is a code to use in ICMP messages where no code is needed. +const ICMPv4UnusedCode ICMPv4Code = 0 + // Type is the ICMP type field. func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) } @@ -93,6 +126,15 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } +// SetPointer sets the pointer field in a Parameter error packet. +// This is the first byte of the type specific data field. +func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c } + +// SetTypeSpecific sets the full 32 bit type specific data field. +func (b ICMPv4) SetTypeSpecific(val uint32) { + binary.BigEndian.PutUint32(b[icmpv4PointerOffset:], val) +} + // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 20b01d8f4..4eb5abd79 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -54,9 +54,17 @@ const ( // address. ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize - // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet. + // ICMPv6EchoMinimumSize is the minimum size of a valid echo packet. ICMPv6EchoMinimumSize = 8 + // ICMPv6ErrorHeaderSize is the size of an ICMP error packet header, + // as per RFC 4443, Apendix A, item 4 and the errata. + // ... all ICMP error messages shall have exactly + // 32 bits of type-specific data, so that receivers can reliably find + // the embedded invoking packet even when they don't recognize the + // ICMP message Type. + ICMPv6ErrorHeaderSize = 8 + // ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP // destination unreachable packet. ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize @@ -69,6 +77,10 @@ const ( // in an ICMPv6 message. icmpv6ChecksumOffset = 2 + // icmpv6PointerOffset is the offset of the pointer + // in an ICMPv6 Parameter problem message. + icmpv6PointerOffset = 4 + // icmpv6MTUOffset is the offset of the MTU field in an ICMPv6 // PacketTooBig message. icmpv6MTUOffset = 4 @@ -89,9 +101,10 @@ const ( NDPHopLimit = 255 ) -// ICMPv6Type is the ICMP type field described in RFC 4443 and friends. +// ICMPv6Type is the ICMP type field described in RFC 4443. type ICMPv6Type byte +// Values for use in the Type field of ICMPv6 packet from RFC 4433. const ( ICMPv6DstUnreachable ICMPv6Type = 1 ICMPv6PacketTooBig ICMPv6Type = 2 @@ -109,7 +122,18 @@ const ( ICMPv6RedirectMsg ICMPv6Type = 137 ) -// ICMPv6Code is the ICMP code field described in RFC 4443. +// IsErrorType returns true if the receiver is an ICMP error type. +func (typ ICMPv6Type) IsErrorType() bool { + // Per RFC 4443 section 2.1: + // ICMPv6 messages are grouped into two classes: error messages and + // informational messages. Error messages are identified as such by a + // zero in the high-order bit of their message Type field values. Thus, + // error messages have message types from 0 to 127; informational + // messages have message types from 128 to 255. + return typ&0x80 == 0 +} + +// ICMPv6Code is the ICMP Code field described in RFC 4443. type ICMPv6Code byte // ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443 @@ -153,6 +177,11 @@ func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) } +// SetTypeSpecific sets the full 32 bit type specific data field. +func (b ICMPv6) SetTypeSpecific(val uint32) { + binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val) +} + // Checksum is the ICMP checksum field. func (b ICMPv6) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:]) diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index e8816c3f4..b07d9991d 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -80,7 +80,8 @@ type IPv4Fields struct { type IPv4 []byte const ( - // IPv4MinimumSize is the minimum size of a valid IPv4 packet. + // IPv4MinimumSize is the minimum size of a valid IPv4 packet; + // i.e. a packet header with no options. IPv4MinimumSize = 20 // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given @@ -327,7 +328,7 @@ func IsV4MulticastAddress(addr tcpip.Address) bool { } // IsV4LoopbackAddress determines if the provided address is an IPv4 loopback -// address (belongs to 127.0.0.1/8 subnet). +// address (belongs to 127.0.0.0/8 subnet). See RFC 1122 section 3.2.1.3. func IsV4LoopbackAddress(addr tcpip.Address) bool { if len(addr) != IPv4AddressSize { return false diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index cb9225bd7..81e286e80 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -238,6 +238,12 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return 0, false, parse.ARP(pkt) } +// ReturnError implements stack.TransportProtocol.ReturnError. +func (*protocol) ReturnError(*stack.Route, tcpip.ICMPReason, *stack.PacketBuffer) *tcpip.Error { + // In ARP, there is no such response so do nothing. + return nil +} + // NewProtocol returns an ARP network protocol. func NewProtocol() stack.NetworkProtocol { return &protocol{} diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index b5659a36b..5fe73315f 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -105,7 +106,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // source address MUST be one of its own IP addresses (but not a broadcast // or multicast address). localAddr := r.LocalAddress - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) { + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) { localAddr = "" } @@ -131,7 +132,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: dataVV, }) - + // TODO(gvisor.dev/issue/3810): When adding protocol numbers into the header + // information we will have to change this code to handle the ICMP header + // no longer being in the data buffer. + replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber // Send out the reply packet. sent := stats.ICMP.V4PacketsSent if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ @@ -193,3 +197,175 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { received.Invalid.Increment() } } + +// ======= ICMP Error packet generation ========= + +// ReturnError implements stack.TransportProtocol.ReturnError. +func (p *protocol) ReturnError(r *stack.Route, reason tcpip.ICMPReason, pkt *stack.PacketBuffer) *tcpip.Error { + switch reason.(type) { + case *tcpip.ICMPReasonPortUnreachable: + return returnError(r, &icmpReasonPortUnreachable{}, pkt) + default: + return tcpip.ErrNotSupported + } +} + +// icmpReason is a marker interface for IPv4 specific ICMP errors. +type icmpReason interface { + isICMPReason() +} + +// icmpReasonPortUnreachable is an error where the transport protocol has no +// listener and no alternative means to inform the sender. +type icmpReasonPortUnreachable struct{} + +func (*icmpReasonPortUnreachable) isICMPReason() {} + +// returnError takes an error descriptor and generates the appropriate ICMP +// error packet for IPv4 and sends it back to the remote device that sent +// the problematic packet. It incorporates as much of that packet as +// possible as well as any error metadata as is available. returnError +// expects pkt to hold a valid IPv4 packet as per the wire format. +func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + sent := r.Stats().ICMP.V4PacketsSent + if !r.Stack().AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + + // We check we are responding only when we are allowed to. + // See RFC 1812 section 4.3.2.7 (shown below). + // + // ========= + // 4.3.2.7 When Not to Send ICMP Errors + // + // An ICMP error message MUST NOT be sent as the result of receiving: + // + // o An ICMP error message, or + // + // o A packet which fails the IP header validation tests described in + // Section [5.2.2] (except where that section specifically permits + // the sending of an ICMP error message), or + // + // o A packet destined to an IP broadcast or IP multicast address, or + // + // o A packet sent as a Link Layer broadcast or multicast, or + // + // o Any fragment of a datagram other then the first fragment (i.e., a + // packet for which the fragment offset in the IP header is nonzero). + // + // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in + // response to a non-initial fragment, but it currently can not happen. + + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any { + return nil + } + + networkHeader := pkt.NetworkHeader().View() + transportHeader := pkt.TransportHeader().View() + + // Don't respond to icmp error packets. + if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) { + // TODO(gvisor.dev/issue/3810): + // Unfortunately the current stack pretty much always has ICMPv4 headers + // in the Data section of the packet but there is no guarantee that is the + // case. If this is the case grab the header to make it like all other + // packet types. When this is cleaned up the Consume should be removed. + if transportHeader.IsEmpty() { + var ok bool + transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize) + if !ok { + return nil + } + } else if transportHeader.Size() < header.ICMPv4MinimumSize { + return nil + } + // We need to decide to explicitly name the packets we can respond to or + // the ones we can not respond to. The decision is somewhat arbitrary and + // if problems arise this could be reversed. It was judged less of a breach + // of protocol to not respond to unknown non-error packets than to respond + // to unknown error packets so we take the first approach. + switch header.ICMPv4(transportHeader).Type() { + case + header.ICMPv4EchoReply, + header.ICMPv4Echo, + header.ICMPv4Timestamp, + header.ICMPv4TimestampReply, + header.ICMPv4InfoRequest, + header.ICMPv4InfoReply: + default: + // Assume any type we don't know about may be an error type. + return nil + } + } else if transportHeader.IsEmpty() { + return nil + } + + // Now work out how much of the triggering packet we should return. + // As per RFC 1812 Section 4.3.2.3 + // + // ICMP datagram SHOULD contain as much of the original + // datagram as possible without the length of the ICMP + // datagram exceeding 576 bytes. + // + // NOTE: The above RFC referenced is different from the original + // recommendation in RFC 1122 and RFC 792 where it mentioned that at + // least 8 bytes of the payload must be included. Today linux and other + // systems implement the RFC 1812 definition and not the original + // requirement. We treat 8 bytes as the minimum but will try send more. + mtu := int(r.MTU()) + if mtu > header.IPv4MinimumProcessableDatagramSize { + mtu = header.IPv4MinimumProcessableDatagramSize + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize + available := int(mtu) - headerLen + + if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize { + return nil + } + + payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size() + if payloadLen > available { + payloadLen = available + } + + // The buffers used by pkt may be used elsewhere in the system. + // For example, an AF_RAW or AF_PACKET socket may use what the transport + // protocol considers an unreachable destination. Thus we deep copy pkt to + // prevent multiple ownership and SR errors. The new copy is a vectorized + // view with the entire incoming IP packet reassembled and truncated as + // required. This is now the payload of the new ICMP packet and no longer + // considered a packet in its own right. + newHeader := append(buffer.View(nil), networkHeader...) + newHeader = append(newHeader, transportHeader...) + payload := newHeader.ToVectorisedView() + payload.AppendView(pkt.Data.ToView()) + payload.CapLength(payloadLen) + + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, + }) + icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber + + icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4PortUnreachable) + counter := sent.DstUnreachable + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + + if err := r.WritePacket( + nil, /* gso */ + stack.NetworkHeaderParams{ + Protocol: header.ICMPv4ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, + icmpPkt, + ); err != nil { + sent.Dropped.Increment() + return err + } + counter.Increment() + return nil +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b14b356d6..135444222 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -455,6 +455,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { + // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport + // headers, the setting of the transport number here should be + // unnecessary and removed. + pkt.TransportProtocolNumber = p e.handleICMP(r, pkt) return } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index b14bc98e8..86187aba8 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -17,6 +17,7 @@ package ipv4_test import ( "bytes" "encoding/hex" + "math" "testing" "github.com/google/go-cmp/cmp" @@ -160,47 +161,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI } } -type testRoute struct { - stack.Route - - linkEP *testutil.TestEndpoint -} - -func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute { - // Make the packet and write it. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - }) - testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors) - s.CreateNIC(1, testEP) - const ( - src = "\x10\x00\x00\x01" - dst = "\x10\x00\x00\x02" - ) - s.AddAddress(1, ipv4.ProtocolNumber, src) - { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}) - } - r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) - } - t.Cleanup(func() { - testEP.Close() - }) - return testRoute{ - Route: r, - linkEP: testEP, - } -} - func TestFragmentation(t *testing.T) { var manyPayloadViewsSizes [1000]int for i := range manyPayloadViewsSizes { @@ -228,7 +188,8 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil) + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber) source := pkt.Clone() err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ @@ -237,16 +198,16 @@ func TestFragmentation(t *testing.T) { TOS: stack.DefaultTOS, }, pkt) if err != nil { - t.Errorf("err got %v, want %v", err, nil) + t.Errorf("got err = %s, want = nil", err) } - if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want { - t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want) + if got := len(ep.WrittenPackets); got != ft.expectedFrags { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags) } - if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want) + if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { + t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want) } - compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu) + compareFragments(t, ep.WrittenPackets, source, ft.mtu) }) } } @@ -259,35 +220,30 @@ func TestFragmentationErrors(t *testing.T) { mtu uint32 transportHeaderLength int payloadViewsSizes []int - packetCollectorErrors []*tcpip.Error + err *tcpip.Error + allowPackets int }{ - {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, - {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, - {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, + {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0}, + {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0}, + {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1}, + {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0}, } for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets) + r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, }, pkt) - for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { - if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { - t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) - } - } - // We only need to check that last error because all the ones before are - // nil. - if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { - t.Errorf("err got %v, want %v", got, want) + if err != ft.err { + t.Errorf("got WritePacket() = %s, want = %s", err, ft.err) } - if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { - t.Errorf("after linkEP error len(result) got %d, want %d", got, want) + if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) } }) } @@ -1052,7 +1008,7 @@ func TestWriteStats(t *testing.T) { tests := []struct { name string setup func(*testing.T, *stack.Stack) - linkEP func() stack.LinkEndpoint + allowPackets int expectSent int expectDropped int expectWritten int @@ -1061,7 +1017,7 @@ func TestWriteStats(t *testing.T) { name: "Accept all", // No setup needed, tables accept everything by default. setup: func(*testing.T, *stack.Stack) {}, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: nPackets, expectDropped: 0, expectWritten: nPackets, @@ -1069,7 +1025,7 @@ func TestWriteStats(t *testing.T) { name: "Accept all with error", // No setup needed, tables accept everything by default. setup: func(*testing.T, *stack.Stack) {}, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} }, + allowPackets: nPackets - 1, expectSent: nPackets - 1, expectDropped: 0, expectWritten: nPackets - 1, @@ -1086,10 +1042,10 @@ func TestWriteStats(t *testing.T) { ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = stack.DropTarget{} if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) + t.Fatalf("failed to replace table: %s", err) } }, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: 0, expectDropped: nPackets, expectWritten: nPackets, @@ -1111,10 +1067,10 @@ func TestWriteStats(t *testing.T) { // Make sure the next rule is ACCEPT. filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { - t.Fatalf("failed to replace table: %v", err) + t.Fatalf("failed to replace table: %s", err) } }, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: nPackets - 1, expectDropped: 1, expectWritten: nPackets, @@ -1150,7 +1106,8 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - rt := buildRoute(t, nil, test.linkEP()) + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets) + rt := buildRoute(t, ep) var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { @@ -1181,101 +1138,37 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, }) - s.CreateNIC(1, linkEP) + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC(1, _) failed: %s", err) + } const ( src = "\x10\x00\x00\x01" dst = "\x10\x00\x00\x02" ) - s.AddAddress(1, ipv4.ProtocolNumber, src) + if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err) + } { subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) if err != nil { - t.Fatal(err) + t.Fatalf("NewSubnet(_, _) failed: %v", err) } s.SetRouteTable([]tcpip.Route{{ Destination: subnet, NIC: 1, }}) } - rt, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) + rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) + t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err) } return rt } -// limitedEP is a link endpoint that writes up to a certain number of packets -// before returning errors. -type limitedEP struct { - limit int -} - -// MTU implements LinkEndpoint.MTU. -func (*limitedEP) MTU() uint32 { - // Give an MTU that won't cause fragmentation for IPv4+UDP. - return header.IPv4MinimumSize + header.UDPMinimumSize -} - -// Capabilities implements LinkEndpoint.Capabilities. -func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 } - -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*limitedEP) MaxHeaderLength() uint16 { return 0 } - -// LinkAddress implements LinkEndpoint.LinkAddress. -func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" } - -// WritePacket implements LinkEndpoint.WritePacket. -func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { - if ep.limit == 0 { - return tcpip.ErrInvalidEndpointState - } - ep.limit-- - return nil -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - if ep.limit == 0 { - return 0, tcpip.ErrInvalidEndpointState - } - nWritten := ep.limit - if nWritten > pkts.Len() { - nWritten = pkts.Len() - } - ep.limit -= nWritten - return nWritten, nil -} - -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - if ep.limit == 0 { - return tcpip.ErrInvalidEndpointState - } - ep.limit-- - return nil -} - -// Attach implements LinkEndpoint.Attach. -func (*limitedEP) Attach(_ stack.NetworkDispatcher) {} - -// IsAttached implements LinkEndpoint.IsAttached. -func (*limitedEP) IsAttached() bool { return false } - -// Wait implements LinkEndpoint.Wait. -func (*limitedEP) Wait() {} - -// ARPHardwareType implements LinkEndpoint.ARPHardwareType. -func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther } - -// AddHeader implements LinkEndpoint.AddHeader. -func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { -} - // limitedMatcher is an iptables matcher that matches after a certain number of // packets are checked against it. type limitedMatcher struct { diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index cd5fe3ea8..8bd8f5c52 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -35,6 +35,7 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 2b83c421e..072c8ccd7 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -318,6 +318,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()), }) packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(packet.NDPPayload()) na.SetSolicitedFlag(solicited) @@ -438,6 +439,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme Data: pkt.Data, }) packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) @@ -477,7 +479,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme stack := r.Stack() // Is the networking stack operating as a router? - if !stack.Forwarding() { + if !stack.Forwarding(ProtocolNumber) { // ... No, silently drop the packet. received.RouterOnlyPacketsDroppedByHost.Increment() return @@ -637,6 +639,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize, }) icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr.SetType(header.ICMPv6NeighborSolicit) copy(icmpHdr[icmpV6OptOffset-len(addr):], addr) icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr @@ -665,3 +668,123 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } return tcpip.LinkAddress([]byte(nil)), false } + +// ======= ICMP Error packet generation ========= + +// ReturnError implements stack.TransportProtocol.ReturnError. +func (p *protocol) ReturnError(r *stack.Route, reason tcpip.ICMPReason, pkt *stack.PacketBuffer) *tcpip.Error { + switch reason.(type) { + case *tcpip.ICMPReasonPortUnreachable: + return returnError(r, &icmpReasonPortUnreachable{}, pkt) + default: + return tcpip.ErrNotSupported + } +} + +// icmpReason is a marker interface for IPv6 specific ICMP errors. +type icmpReason interface { + isICMPReason() +} + +// icmpReasonPortUnreachable is an error where the transport protocol has no +// listener and no alternative means to inform the sender. +type icmpReasonPortUnreachable struct{} + +func (*icmpReasonPortUnreachable) isICMPReason() {} + +// returnError takes an error descriptor and generates the appropriate ICMP +// error packet for IPv6 and sends it. +func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + stats := r.Stats().ICMP + sent := stats.V6PacketsSent + if !r.Stack().AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + + // Only send ICMP error if the address is not a multicast v6 + // address and the source is not the unspecified address. + // + // TODO(b/164522993) There are exceptions to this rule. + // See: point e.3) RFC 4443 section-2.4 + // + // (e) An ICMPv6 error message MUST NOT be originated as a result of + // receiving the following: + // + // (e.1) An ICMPv6 error message. + // + // (e.2) An ICMPv6 redirect message [IPv6-DISC]. + // + // (e.3) A packet destined to an IPv6 multicast address. (There are + // two exceptions to this rule: (1) the Packet Too Big Message + // (Section 3.2) to allow Path MTU discovery to work for IPv6 + // multicast, and (2) the Parameter Problem Message, Code 2 + // (Section 3.4) reporting an unrecognized IPv6 option (see + // Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + // + if header.IsV6MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv6Any { + return nil + } + + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + + if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { + // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. + // Unfortunately at this time ICMP Packets do not have a transport + // header separated out. It is in the Data part so we need to + // separate it out now. We will just pretend it is a minimal length + // ICMP packet as we don't really care if any later bits of a + // larger ICMP packet are in the header view or in the Data view. + transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize) + if !ok { + return nil + } + typ := header.ICMPv6(transport).Type() + if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg { + return nil + } + } + + // As per RFC 4443 section 2.4 + // + // (c) Every ICMPv6 error message (type < 128) MUST include + // as much of the IPv6 offending (invoking) packet (the + // packet that caused the error) as possible without making + // the error message packet exceed the minimum IPv6 MTU + // [IPv6]. + mtu := int(r.MTU()) + if mtu > header.IPv6MinimumMTU { + mtu = header.IPv6MinimumMTU + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize + available := int(mtu) - headerLen + if available < header.IPv6MinimumSize { + return nil + } + payloadLen := network.Size() + transport.Size() + pkt.Data.Size() + if payloadLen > available { + payloadLen = available + } + payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + payload.CapLength(payloadLen) + + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, + }) + newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber + + icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) + icmpHdr.SetCode(header.ICMPv6PortUnreachable) + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data)) + counter := sent.DstUnreachable + err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt) + if err != nil { + sent.Dropped.Increment() + return err + } + counter.Increment() + return nil +} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 8112ed051..0f50bfb8e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -728,7 +728,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { }) if isRouter { // Enabling forwarding makes the stack act as a router. - s.SetForwarding(true) + s.SetForwarding(ProtocolNumber, true) } if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index ee64d92d8..5b1cca180 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -348,7 +348,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { it, done, err := it.Next() if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() return } if done { @@ -476,6 +476,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { pkt.Data = extHdr.Buf if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { + pkt.TransportProtocolNumber = p e.handleICMP(r, pkt, hasFragmentHeader) } else { r.Stats().IP.PacketsDelivered.Increment() diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 9eea1de8d..7d138dadb 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -1715,7 +1716,7 @@ func TestWriteStats(t *testing.T) { tests := []struct { name string setup func(*testing.T, *stack.Stack) - linkEP func() stack.LinkEndpoint + allowPackets int expectSent int expectDropped int expectWritten int @@ -1724,7 +1725,7 @@ func TestWriteStats(t *testing.T) { name: "Accept all", // No setup needed, tables accept everything by default. setup: func(*testing.T, *stack.Stack) {}, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: nPackets, expectDropped: 0, expectWritten: nPackets, @@ -1732,7 +1733,7 @@ func TestWriteStats(t *testing.T) { name: "Accept all with error", // No setup needed, tables accept everything by default. setup: func(*testing.T, *stack.Stack) {}, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} }, + allowPackets: nPackets - 1, expectSent: nPackets - 1, expectDropped: 0, expectWritten: nPackets - 1, @@ -1752,7 +1753,7 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %v", err) } }, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: 0, expectDropped: nPackets, expectWritten: nPackets, @@ -1777,7 +1778,7 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %v", err) } }, - linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + allowPackets: math.MaxInt32, expectSent: nPackets - 1, expectDropped: 1, expectWritten: nPackets, @@ -1812,7 +1813,8 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - rt := buildRoute(t, nil, test.linkEP()) + ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) + rt := buildRoute(t, ep) var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { @@ -1843,100 +1845,37 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, }) - s.CreateNIC(1, linkEP) + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC(1, _) failed: %s", err) + } const ( src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" ) - s.AddAddress(1, ProtocolNumber, src) + if err := s.AddAddress(1, ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err) + } { subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")) if err != nil { - t.Fatal(err) + t.Fatalf("NewSubnet(_, _) failed: %v", err) } s.SetRouteTable([]tcpip.Route{{ Destination: subnet, NIC: 1, }}) } - rt, err := s.FindRoute(0, src, dst, ProtocolNumber, false /* multicastLoop */) + rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("s.FindRoute got %v, want %v", err, nil) + t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err) } return rt } -// limitedEP is a link endpoint that writes up to a certain number of packets -// before returning errors. -type limitedEP struct { - limit int -} - -// MTU implements LinkEndpoint.MTU. -func (*limitedEP) MTU() uint32 { - return header.IPv6MinimumMTU -} - -// Capabilities implements LinkEndpoint.Capabilities. -func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 } - -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*limitedEP) MaxHeaderLength() uint16 { return 0 } - -// LinkAddress implements LinkEndpoint.LinkAddress. -func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" } - -// WritePacket implements LinkEndpoint.WritePacket. -func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { - if ep.limit == 0 { - return tcpip.ErrInvalidEndpointState - } - ep.limit-- - return nil -} - -// WritePackets implements LinkEndpoint.WritePackets. -func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - if ep.limit == 0 { - return 0, tcpip.ErrInvalidEndpointState - } - nWritten := ep.limit - if nWritten > pkts.Len() { - nWritten = pkts.Len() - } - ep.limit -= nWritten - return nWritten, nil -} - -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - if ep.limit == 0 { - return tcpip.ErrInvalidEndpointState - } - ep.limit-- - return nil -} - -// Attach implements LinkEndpoint.Attach. -func (*limitedEP) Attach(_ stack.NetworkDispatcher) {} - -// IsAttached implements LinkEndpoint.IsAttached. -func (*limitedEP) IsAttached() bool { return false } - -// Wait implements LinkEndpoint.Wait. -func (*limitedEP) Wait() {} - -// ARPHardwareType implements LinkEndpoint.ARPHardwareType. -func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther } - -// AddHeader implements LinkEndpoint.AddHeader. -func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { -} - // limitedMatcher is an iptables matcher that matches after a certain number of // packets are checked against it. type limitedMatcher struct { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 480c495fa..7434df4a1 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -958,7 +958,7 @@ func TestNDPValidation(t *testing.T) { if isRouter { // Enabling forwarding makes the stack act as a router. - s.SetForwarding(true) + s.SetForwarding(ProtocolNumber, true) } stats := s.Stats().ICMP.V6PacketsReceived diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD index e218563d0..c9e57dc0d 100644 --- a/pkg/tcpip/network/testutil/BUILD +++ b/pkg/tcpip/network/testutil/BUILD @@ -7,11 +7,14 @@ go_library( srcs = [ "testutil.go", ], - visibility = ["//pkg/tcpip/network/ipv4:__pkg__"], + visibility = [ + "//pkg/tcpip/network/ipv4:__pkg__", + "//pkg/tcpip/network/ipv6:__pkg__", + ], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", - "//pkg/tcpip/link/channel", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go index bf5ce74be..7cc52985e 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/testutil/testutil.go @@ -22,48 +22,100 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// TestEndpoint is an endpoint used for testing, it stores packets written to it -// and can mock errors. -type TestEndpoint struct { - *channel.Endpoint - - // WrittenPackets is where we store packets written via WritePacket(). +// MockLinkEndpoint is an endpoint used for testing, it stores packets written +// to it and can mock errors. +type MockLinkEndpoint struct { + // WrittenPackets is where packets written to the endpoint are stored. WrittenPackets []*stack.PacketBuffer - packetCollectorErrors []*tcpip.Error + mtu uint32 + err *tcpip.Error + allowPackets int } -// NewTestEndpoint creates a new TestEndpoint endpoint. +// NewMockLinkEndpoint creates a new MockLinkEndpoint. // -// packetCollectorErrors can be used to set error values and each call to -// WritePacket will remove the first one from the slice and return it until -// the slice is empty - at that point it will return nil every time. -func NewTestEndpoint(ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) *TestEndpoint { - return &TestEndpoint{ - Endpoint: ep, - WrittenPackets: make([]*stack.PacketBuffer, 0), - packetCollectorErrors: packetCollectorErrors, +// err is the error that will be returned once allowPackets packets are written +// to the endpoint. +func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint { + return &MockLinkEndpoint{ + mtu: mtu, + err: err, + allowPackets: allowPackets, + } +} + +// MTU implements LinkEndpoint.MTU. +func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu } + +// Capabilities implements LinkEndpoint.Capabilities. +func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 } + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } + +// LinkAddress implements LinkEndpoint.LinkAddress. +func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } + +// WritePacket implements LinkEndpoint.WritePacket. +func (ep *MockLinkEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if ep.allowPackets == 0 { + return ep.err } + ep.allowPackets-- + ep.WrittenPackets = append(ep.WrittenPackets, pkt) + return nil } -// WritePacket stores outbound packets and may return an error if one was -// injected. -func (e *TestEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.WrittenPackets = append(e.WrittenPackets, pkt) +// WritePackets implements LinkEndpoint.WritePackets. +func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + var n int - if len(e.packetCollectorErrors) > 0 { - nextError := e.packetCollectorErrors[0] - e.packetCollectorErrors = e.packetCollectorErrors[1:] - return nextError + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := ep.WritePacket(r, gso, protocol, pkt); err != nil { + return n, err + } + n++ } + return n, nil +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + if ep.allowPackets == 0 { + return ep.err + } + ep.allowPackets-- + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + ep.WrittenPackets = append(ep.WrittenPackets, pkt) + return nil } +// Attach implements LinkEndpoint.Attach. +func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} + +// IsAttached implements LinkEndpoint.IsAttached. +func (*MockLinkEndpoint) IsAttached() bool { return false } + +// Wait implements LinkEndpoint.Wait. +func (*MockLinkEndpoint) Wait() {} + +// ARPHardwareType implements LinkEndpoint.ARPHardwareType. +func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } + +// AddHeader implements LinkEndpoint.AddHeader. +func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + // MakeRandPkt generates a randomized packet. transportHeaderLength indicates // how many random bytes will be copied in the Transport Header. // extraHeaderReserveLength indicates how much extra space will be reserved for diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 900938dd1..7f1d79115 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -138,7 +138,6 @@ go_test( name = "stack_test", size = "small", srcs = [ - "fake_time_test.go", "forwarder_test.go", "linkaddrcache_test.go", "neighbor_cache_test.go", @@ -152,8 +151,8 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", - "@com_github_dpjacques_clockwork//:go_default_library", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 54759091a..e30927821 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -145,6 +145,10 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } +func (*fwdTestNetworkProtocol) ReturnError(*Route, tcpip.ICMPReason, *PacketBuffer) *tcpip.Error { + return nil +} + func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint { return &fwdTestNetworkEndpoint{ nicID: nicID, @@ -316,7 +320,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC } // Enable forwarding. - s.SetForwarding(true) + s.SetForwarding(proto.Number(), true) // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index b0873d1af..97ca00d16 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -817,7 +817,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // per-interface basis; it is a stack-wide configuration, so we check // stack's forwarding flag to determine if the NIC is a routing // interface. - if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding { + if !ndp.configs.HandleRAs || ndp.nic.stack.Forwarding(header.IPv6ProtocolNumber) { return } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 67dc5364f..5e43a9b0b 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1120,7 +1120,7 @@ func TestNoRouterDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1365,7 +1365,7 @@ func TestNoPrefixDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1723,7 +1723,7 @@ func TestNoAutoGenAddr(t *testing.T) { }, NDPDisp: &ndpDisp, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -4640,7 +4640,7 @@ func TestCleanupNDPState(t *testing.T) { name: "Enable forwarding", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, keepAutoGenLinkLocal: true, maxAutoGenAddrEvents: 4, @@ -5286,11 +5286,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(false) + s.SetForwarding(ipv6.ProtocolNumber, false) }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, }, diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index b4fa69e3e..a0b7da5cd 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -30,6 +30,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) const ( @@ -239,7 +240,7 @@ type entryEvent struct { func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, c, clock) if got, want := neigh.config(), c; got != want { @@ -257,7 +258,7 @@ func TestNeighborCacheGetConfig(t *testing.T) { func TestNeighborCacheSetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, c, clock) c.MinRandomFactor = 1 @@ -279,7 +280,7 @@ func TestNeighborCacheSetConfig(t *testing.T) { func TestNeighborCacheEntry(t *testing.T) { c := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, c, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -298,7 +299,7 @@ func TestNeighborCacheEntry(t *testing.T) { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { @@ -339,7 +340,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -358,7 +359,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { @@ -409,7 +410,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } type testContext struct { - clock *fakeClock + clock *faketime.ManualClock neigh *neighborCache store *testEntryStore linkRes *testNeighborResolver @@ -418,7 +419,7 @@ type testContext struct { func newTestContext(c NUDConfigurations) testContext { nudDisp := &testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(nudDisp, c, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -454,7 +455,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.neigh.config().RetransmitTimer) var wantEvents []testEntryEventInfo @@ -567,7 +568,7 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -803,7 +804,7 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(typicalLatency) + c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -876,7 +877,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -902,7 +903,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { if doneCh == nil { t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: @@ -944,7 +945,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -974,7 +975,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { // Remove the waker before the neighbor cache has the opportunity to send a // notification. neigh.removeWaker(entry.Addr, &w) - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: @@ -1073,7 +1074,7 @@ func TestNeighborCacheClear(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1092,7 +1093,7 @@ func TestNeighborCacheClear(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { @@ -1188,7 +1189,7 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(typicalLatency) + c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1249,7 +1250,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { config.MaxRandomFactor = 1 nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1277,7 +1278,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: default: @@ -1325,7 +1326,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: default: @@ -1412,7 +1413,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1440,7 +1441,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Wait() // Process all the requests for a single entry concurrently - clock.advance(typicalLatency) + clock.Advance(typicalLatency) } // All goroutines add in the same order and add more values than can fit in @@ -1472,7 +1473,7 @@ func TestNeighborCacheReplace(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1491,7 +1492,7 @@ func TestNeighborCacheReplace(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: default: @@ -1541,7 +1542,7 @@ func TestNeighborCacheReplace(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(config.DelayFirstProbeTime + typicalLatency) + clock.Advance(config.DelayFirstProbeTime + typicalLatency) select { case <-doneCh: default: @@ -1552,7 +1553,7 @@ func TestNeighborCacheReplace(t *testing.T) { // Verify the entry's new link address { e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - clock.advance(typicalLatency) + clock.Advance(typicalLatency) if err != nil { t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) } @@ -1572,7 +1573,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() @@ -1595,7 +1596,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) if err != nil { t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) @@ -1618,7 +1619,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) } @@ -1636,7 +1637,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { config := DefaultNUDConfigurations() config.RetransmitTimer = time.Millisecond // small enough to cause timeout - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(nil, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1654,7 +1655,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) } @@ -1664,7 +1665,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { // resolved immediately and don't send resolution requests. func TestNeighborCacheStaticResolution(t *testing.T) { config := DefaultNUDConfigurations() - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(nil, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 0068cacb8..213646160 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -73,8 +73,7 @@ const ( type neighborEntry struct { neighborEntryEntry - nic *NIC - protocol tcpip.NetworkProtocolNumber + nic *NIC // linkRes provides the functionality to send reachability probes, used in // Neighbor Unreachability Detection. diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index b769fb2fa..e530ec7ea 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -27,6 +27,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) const ( @@ -221,8 +222,8 @@ func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return entryTestNetNumber } -func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) { - clock := newFakeClock() +func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) { + clock := faketime.NewManualClock() disp := testNUDDispatcher{} nic := NIC{ id: entryTestNICID, @@ -267,7 +268,7 @@ func TestEntryInitiallyUnknown(t *testing.T) { } e.mu.Unlock() - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) // No probes should have been sent. linkRes.mu.Lock() @@ -300,7 +301,7 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { } e.mu.Unlock() - clock.advance(time.Hour) + clock.Advance(time.Hour) // No probes should have been sent. linkRes.mu.Lock() @@ -410,7 +411,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { updatedAt := e.neigh.UpdatedAt e.mu.Unlock() - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) // UpdatedAt should remain the same during address resolution. wantProbes := []entryTestProbeInfo{ @@ -439,7 +440,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Unlock() - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) // UpdatedAt should change after failing address resolution. Timing out after // sending the last probe transitions the entry to Failed. @@ -459,7 +460,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } } - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) wantEvents := []testEntryEventInfo{ { @@ -748,7 +749,7 @@ func TestEntryIncompleteToFailed(t *testing.T) { e.mu.Unlock() waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) wantProbes := []entryTestProbeInfo{ // The Incomplete-to-Incomplete state transition is tested here by @@ -983,7 +984,7 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -1612,7 +1613,7 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -1706,7 +1707,7 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -1989,7 +1990,7 @@ func TestEntryDelayToProbe(t *testing.T) { } e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2069,7 +2070,7 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2166,7 +2167,7 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2267,7 +2268,7 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2364,7 +2365,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // Probe caused by the Delay-to-Probe transition @@ -2398,7 +2399,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { } e.mu.Unlock() - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -2463,7 +2464,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2503,7 +2504,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { } e.mu.Unlock() - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -2575,7 +2576,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2612,7 +2613,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin } e.mu.Unlock() - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -2682,7 +2683,7 @@ func TestEntryProbeToFailed(t *testing.T) { e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2787,7 +2788,7 @@ func TestEntryFailedGetsDeleted(t *testing.T) { e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime - clock.advance(waitFor) + clock.Advance(waitFor) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 204bfc433..06d70dd1c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -337,7 +337,7 @@ func (n *NIC) enable() *tcpip.Error { // does. That is, routers do not learn from RAs (e.g. on-link prefixes // and default routers). Therefore, soliciting RAs from other routers on // a link is unnecessary for routers. - if !n.stack.forwarding { + if !n.stack.Forwarding(header.IPv6ProtocolNumber) { n.mu.ndp.startSolicitingRouters() } @@ -1242,9 +1242,9 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp local = n.linkEP.LinkAddress() } - // Are any packet sockets listening for this network protocol? + // Are any packet type sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] - // Add any other packet sockets that maybe listening for all protocols. + // Add any other packet type sockets that may be listening for all protocols. packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) n.mu.RUnlock() for _, ep := range packetEPs { @@ -1265,6 +1265,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp return } if hasTransportHdr { + pkt.TransportProtocolNumber = transProtoNum // Parse the transport header if present. if state, ok := n.stack.transportProtocols[transProtoNum]; ok { state.proto.Parse(pkt) @@ -1303,7 +1304,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // packet and forward it to the NIC. // // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding() { + if n.stack.Forwarding(protocol) { r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() @@ -1330,6 +1331,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. // TODO(b/128629022): move this logic to route.WritePacket. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) @@ -1452,10 +1454,28 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN } } - // We could not find an appropriate destination for this packet, so - // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, pkt) { + // We could not find an appropriate destination for this packet so + // give the protocol specific error handler a chance to handle it. + // If it doesn't handle it then we should do so. + switch transProto.HandleUnknownDestinationPacket(r, id, pkt) { + case UnknownDestinationPacketMalformed: n.stack.stats.MalformedRcvdPackets.Increment() + case UnknownDestinationPacketUnhandled: + // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination + // Unreachable messages with code: + // 3 (Port Unreachable), when the designated transport protocol + // (e.g., UDP) is unable to demultiplex the datagram but has no + // protocol mechanism to inform the sender. + np, ok := n.stack.networkProtocols[r.NetProto] + if !ok { + // For this to happen stack.makeRoute() must have been called with the + // incorrect protocol number. Since we have successfully completed + // network layer processing this should be impossible. + panic(fmt.Sprintf("expected stack to have a NetworkProtocol for proto = %d", r.NetProto)) + } + + _ = np.ReturnError(r, &tcpip.ICMPReasonPortUnreachable{}, pkt) + case UnknownDestinationPacketHandled: } } diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index dd6474297..ef6e63b3e 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -221,6 +221,11 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo return 0, false, false } +// ReturnError implements NetworkProtocol.ReturnError. +func (*testIPv6Protocol) ReturnError(*Route, tcpip.ICMPReason, *PacketBuffer) *tcpip.Error { + return nil +} + var _ LinkAddressResolver = (*testIPv6Protocol)(nil) // LinkAddressProtocol implements LinkAddressResolver. diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 1932aaeb7..a7d9d59fa 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -80,11 +80,17 @@ type PacketBuffer struct { // data are held in the same underlying buffer storage. header buffer.Prependable - // NetworkProtocolNumber is only valid when NetworkHeader is set. + // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty() + // returns false. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol // numbers in registration APIs that take a PacketBuffer. NetworkProtocolNumber tcpip.NetworkProtocolNumber + // TransportProtocol is only valid if it is non zero. + // TODO(gvisor.dev/issue/3810): This and the network protocol number should + // be moved into the headerinfo. This should resolve the validity issue. + TransportProtocolNumber tcpip.TransportProtocolNumber + // Hash is the transport layer hash of this packet. A value of zero // indicates no valid hash has been set. Hash uint32 @@ -234,16 +240,17 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum // underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { newPk := &PacketBuffer{ - PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), - headers: pk.headers, - header: pk.header, - Hash: pk.Hash, - Owner: pk.Owner, - EgressRoute: pk.EgressRoute, - GSOOptions: pk.GSOOptions, - NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + headers: pk.headers, + header: pk.header, + Hash: pk.Hash, + Owner: pk.Owner, + EgressRoute: pk.EgressRoute, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + TransportProtocolNumber: pk.TransportProtocolNumber, } return newPk } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 4fa86a3ac..77640cd8a 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -125,6 +125,26 @@ type PacketEndpoint interface { HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } +// UnknownDestinationPacketDisposition enumerates the possible return vaues from +// HandleUnknownDestinationPacket(). +type UnknownDestinationPacketDisposition int + +const ( + // UnknownDestinationPacketMalformed denotes that the packet was malformed + // and no further processing should be attempted other than updating + // statistics. + UnknownDestinationPacketMalformed UnknownDestinationPacketDisposition = iota + + // UnknownDestinationPacketUnhandled tells the caller that the packet was + // well formed but that the issue was not handled and the stack should take + // the default action. + UnknownDestinationPacketUnhandled + + // UnknownDestinationPacketHandled tells the caller that it should do + // no further processing. + UnknownDestinationPacketHandled +) + // TransportProtocol is the interface that needs to be implemented by transport // protocols (e.g., tcp, udp) that want to be part of the networking stack. type TransportProtocol interface { @@ -147,14 +167,12 @@ type TransportProtocol interface { ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this - // protocol but that don't match any existing endpoint. For example, - // it is targeted at a port that have no listeners. + // protocol that don't match any existing endpoint. For example, + // it is targeted at a port that has no listeners. // - // The return value indicates whether the packet was well-formed (for - // stats purposes only). - // - // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool + // HandleUnknownDestinationPacket takes ownership of pkt if it handles + // the issue. + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -324,6 +342,19 @@ type NetworkProtocol interface { // does not encapsulate anything). // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader. Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) + + // ReturnError attempts to send a suitable error message to the sender + // of a received packet. + // - pkt holds the problematic packet. + // - reason indicates what the reason for wanting a message is. + // - route is the routing information for the received packet + // ReturnError returns an error if the send failed and nil on success. + // Note that deciding to deliberately send no message is a success. + // + // TODO(gvisor.dev/issues/3871): This method should be removed or simplified + // after all (or all but one) of the ICMP error dispatch occurs through the + // protocol specific modules. May become SendPortNotFound(r, pkt). + ReturnError(r *Route, reason tcpip.ICMPReason, pkt *PacketBuffer) *tcpip.Error } // NetworkDispatcher contains the methods used by the network stack to deliver diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6a683545d..e7b7e95d4 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -144,10 +144,7 @@ type TCPReceiverState struct { // PendingBufUsed is the number of bytes pending in the receive // queue. - PendingBufUsed seqnum.Size - - // PendingBufSize is the size of the socket receive buffer. - PendingBufSize seqnum.Size + PendingBufUsed int } // TCPSenderState holds a copy of the internal state of the sender for @@ -405,6 +402,13 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + // forwarding contains the whether packet forwarding is enabled or not for + // different network protocols. + forwarding struct { + sync.RWMutex + protocols map[tcpip.NetworkProtocolNumber]bool + } + // rawFactory creates raw endpoints. If nil, raw endpoints are // disabled. It is set during Stack creation and is immutable. rawFactory RawFactory @@ -415,9 +419,8 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool + mu sync.RWMutex + nics map[tcpip.NICID]*NIC // cleanupEndpointsMu protects cleanupEndpoints. cleanupEndpointsMu sync.Mutex @@ -749,6 +752,7 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, } + s.forwarding.protocols = make(map[tcpip.NetworkProtocolNumber]bool) // Add specified network protocols. for _, netProto := range opts.NetworkProtocols { @@ -866,46 +870,42 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables the packet forwarding between NICs. -// -// When forwarding becomes enabled, any host-only state on all NICs will be -// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started. -// When forwarding becomes disabled and if IPv6 is enabled, NDP Router -// Solicitations will be stopped. -func (s *Stack) SetForwarding(enable bool) { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.Lock() - defer s.mu.Unlock() +// SetForwarding enables or disables packet forwarding between NICs. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) { + s.forwarding.Lock() + defer s.forwarding.Unlock() - // If forwarding status didn't change, do nothing further. - if s.forwarding == enable { + // If this stack does not support the protocol, do nothing. + if _, ok := s.networkProtocols[protocol]; !ok { return } - s.forwarding = enable - - // If this stack does not support IPv6, do nothing further. - if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok { + // If the forwarding value for this protocol hasn't changed then do + // nothing. + if forwarding := s.forwarding.protocols[protocol]; forwarding == enable { return } - if enable { - for _, nic := range s.nics { - nic.becomeIPv6Router() - } - } else { - for _, nic := range s.nics { - nic.becomeIPv6Host() + s.forwarding.protocols[protocol] = enable + + if protocol == header.IPv6ProtocolNumber { + if enable { + for _, nic := range s.nics { + nic.becomeIPv6Router() + } + } else { + for _, nic := range s.nics { + nic.becomeIPv6Host() + } } } } -// Forwarding returns if the packet forwarding between NICs is enabled. -func (s *Stack) Forwarding() bool { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.RLock() - defer s.mu.RUnlock() - return s.forwarding +// Forwarding returns if packet forwarding between NICs is enabled. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + s.forwarding.RLock() + defer s.forwarding.RUnlock() + return s.forwarding.protocols[protocol] } // SetRouteTable assigns the route table to be used by this stack. It diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 60b54c244..9ef6787c6 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -216,13 +216,18 @@ func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) } } -// Close implements TransportProtocol.Close. +// ReturnError implements NetworkProtocol.ReturnError +func (*fakeNetworkProtocol) ReturnError(*stack.Route, tcpip.ICMPReason, *stack.PacketBuffer) *tcpip.Error { + return nil +} + +// Close implements NetworkProtocol.Close. func (*fakeNetworkProtocol) Close() {} -// Wait implements TransportProtocol.Wait. +// Wait implements NetworkProtocol.Wait. func (*fakeNetworkProtocol) Wait() {} -// Parse implements TransportProtocol.Parse. +// Parse implements NetworkProtocol.Parse. func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) if !ok { @@ -2091,7 +2096,7 @@ func TestNICForwarding(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID1, ep1); err != nil { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index ef3457e32..cbb34d224 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -287,8 +287,8 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { - return true +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + return stack.UnknownDestinationPacketHandled } func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { @@ -549,7 +549,7 @@ func TestTransportForwarding(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) // TODO(b/123449044): Change this to a channel NIC. ep1 := loopback.New() diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 464608dee..fa73cfa47 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1987,3 +1987,14 @@ func DeleteDanglingEndpoint(e Endpoint) { // AsyncLoading is the global barrier for asynchronous endpoint loading // activities. var AsyncLoading sync.WaitGroup + +// ICMPReason is a marker interface for network protocol agnostic ICMP errors. +type ICMPReason interface { + isICMP() +} + +// ICMPReasonPortUnreachable is an error where the transport protocol has no +// listener and no alternative means to inform the sender. +type ICMPReasonPortUnreachable struct{} + +func (*ICMPReasonPortUnreachable) isICMP() {} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 31116309e..41eb0ca44 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -446,6 +446,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi pkt.Owner = owner icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber copy(icmpv4, data) // Set the ident to the user-specified port. Sequence number should // already be set by the user. @@ -478,6 +479,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err }) icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber copy(icmpv6, data) // Set the ident. Sequence number is provided by the user. icmpv6.SetIdent(ident) diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index bb11e4e83..941c3c08d 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,8 +104,8 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { - return true +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + return stack.UnknownDestinationPacketHandled } // SetOption implements stack.TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 4778e7b1c..518449602 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -94,6 +94,7 @@ go_test( shard_count = 10, deps = [ ":tcp", + "//pkg/rand", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 09d53d158..6891fd245 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -747,6 +747,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen)) + pkt.TransportProtocolNumber = header.TCPProtocolNumber tcp.Encode(&header.TCPFields{ SrcPort: tf.id.LocalPort, DstPort: tf.id.RemotePort, @@ -897,7 +898,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -1002,9 +1003,8 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // (indicated by a negative send window scale). e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - rcvBufSize := seqnum.Size(e.receiveBufferSize()) e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) // Bootstrap the auto tuning algorithm. Starting at zero will // result in a really large receive window after the first auto // tuning adjustment. @@ -1135,12 +1135,11 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { } cont, err := e.handleSegment(s) + s.decRef() if err != nil { - s.decRef() return err } if !cont { - s.decRef() return nil } } @@ -1220,6 +1219,12 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { return true, nil } + // Increase counter if after processing the segment we would potentially + // advertise a zero window. + if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above { + e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() + } + // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. @@ -1232,7 +1237,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // or a notification from the protocolMainLoop (caller goroutine). // This means that with this return, the segment dequeue below can // never occur on a closed endpoint. - s.decRef() return false, nil } @@ -1424,10 +1428,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.rcv.nonZeroWindow() } - if n¬ifyReceiveWindowChanged != 0 { - e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize()) - } - if n¬ifyMTUChanged != 0 { e.sndBufMu.Lock() count := e.packetTooBigCount diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 94207c141..560b4904c 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -78,8 +78,8 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network ackCheckers := append(checkers, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), )) checker.IPv4(t, c.GetPacket(), ackCheckers...) @@ -185,8 +185,8 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network ackCheckers := append(checkers, checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), )) checker.IPv6(t, c.GetV6Packet(), ackCheckers...) @@ -283,7 +283,7 @@ func TestV4RefuseOnV6Only(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) } @@ -319,7 +319,7 @@ func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) } @@ -352,7 +352,7 @@ func testV4Accept(t *testing.T, c *context.Context) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) @@ -492,7 +492,7 @@ func TestV6AcceptOnV6(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1), + checker.TCPAckNum(uint32(irs)+1), ), ) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 120483838..87db13720 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -63,6 +63,17 @@ const ( StateClosing ) +const ( + // rcvAdvWndScale is used to split the available socket buffer into + // application buffer and the window to be advertised to the peer. This is + // currently hard coded to split the available space equally. + rcvAdvWndScale = 1 + + // SegOverheadFactor is used to multiply the value provided by the + // user on a SetSockOpt for setting the socket send/receive buffer sizes. + SegOverheadFactor = 2 +) + // connected returns true when s is one of the states representing an // endpoint connected to a peer. func (s EndpointState) connected() bool { @@ -149,7 +160,6 @@ func (s EndpointState) String() string { // Reasons for notifying the protocol goroutine. const ( notifyNonZeroReceiveWindow = 1 << iota - notifyReceiveWindowChanged notifyClose notifyMTUChanged notifyDrain @@ -384,13 +394,26 @@ type endpoint struct { // to indicate to users that no more data is coming. // // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` - rcvList segmentList `state:"wait"` - rcvClosed bool - rcvBufSize int + rcvListMu sync.Mutex `state:"nosave"` + rcvList segmentList `state:"wait"` + rcvClosed bool + // rcvBufSize is the total size of the receive buffer. + rcvBufSize int + // rcvBufUsed is the actual number of payload bytes held in the receive buffer + // not counting any overheads of the segments itself. NOTE: This will always + // be strictly <= rcvMemUsed below. rcvBufUsed int rcvAutoParams rcvBufAutoTuneParams + // rcvMemUsed tracks the total amount of memory in use by received segments + // held in rcvList, pendingRcvdSegments and the segment queue. This is used to + // compute the window and the actual available buffer space. This is distinct + // from rcvBufUsed above which is the actual number of payload bytes held in + // the buffer not including any segment overheads. + // + // rcvMemUsed must be accessed atomically. + rcvMemUsed int32 + // mu protects all endpoint fields unless documented otherwise. mu must // be acquired before interacting with the endpoint fields. mu sync.Mutex `state:"nosave"` @@ -891,7 +914,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.probe = p } - e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.segmentQueue.ep = e e.tsOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) @@ -1129,10 +1152,16 @@ func (e *endpoint) cleanupLocked() { tcpip.DeleteDanglingEndpoint(e) } +// wndFromSpace returns the window that we can advertise based on the available +// receive buffer space. +func wndFromSpace(space int) int { + return space / (1 << rcvAdvWndScale) +} + // initialReceiveWindow returns the initial receive window to advertise in the // SYN/SYN-ACK. func (e *endpoint) initialReceiveWindow() int { - rcvWnd := e.receiveBufferAvailable() + rcvWnd := wndFromSpace(e.receiveBufferAvailable()) if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } @@ -1209,14 +1238,12 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // reject valid data that might already be in flight as the // acceptable window will shrink. if rcvWnd > e.rcvBufSize { - availBefore := e.receiveBufferAvailableLocked() + availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvBufSize = rcvWnd - availAfter := e.receiveBufferAvailableLocked() - mask := uint32(notifyReceiveWindowChanged) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - mask |= notifyNonZeroReceiveWindow + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } - e.notifyProtocolGoroutine(mask) } // We only update prevCopied when we grow the buffer because in cases @@ -1293,18 +1320,22 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { v := views[s.viewToDeliver] s.viewToDeliver++ + var delta int if s.viewToDeliver >= len(views) { e.rcvList.Remove(s) + // We only free up receive buffer space when the segment is released as the + // segment is still holding on to the views even though some views have been + // read out to the user. + delta = s.segMemSize() s.decRef() } e.rcvBufUsed -= len(v) - // If the window was small before this read and if the read freed up // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -1481,11 +1512,11 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro } // windowCrossedACKThresholdLocked checks if the receive window to be announced -// now would be under aMSS or under half receive buffer, whichever smaller. This -// is useful as a receive side silly window syndrome prevention mechanism. If -// window grows to reasonable value, we should send ACK to the sender to inform -// the rx space is now large. We also want ensure a series of small read()'s -// won't trigger a flood of spurious tiny ACK's. +// would be under aMSS or under the window derived from half receive buffer, +// whichever smaller. This is useful as a receive side silly window syndrome +// prevention mechanism. If window grows to reasonable value, we should send ACK +// to the sender to inform the rx space is now large. We also want ensure a +// series of small read()'s won't trigger a flood of spurious tiny ACK's. // // For large receive buffers, the threshold is aMSS - once reader reads more // than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of @@ -1496,17 +1527,18 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // // Precondition: e.mu and e.rcvListMu must be held. func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := e.receiveBufferAvailableLocked() + newAvail := wndFromSpace(e.receiveBufferAvailableLocked()) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 } - threshold := int(e.amss) - if threshold > e.rcvBufSize/2 { - threshold = e.rcvBufSize / 2 + // rcvBufFraction is the inverse of the fraction of receive buffer size that + // is used to decide if the available buffer space is now above it. + const rcvBufFraction = 2 + if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold { + threshold = wndThreshold } - switch { case oldAvail < threshold && newAvail >= threshold: return true, true @@ -1636,17 +1668,23 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Make sure the receive buffer size is within the min and max // allowed. var rs tcpip.TCPReceiveBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err)) + } + + if v > rs.Max { + v = rs.Max + } + + if v < math.MaxInt32/SegOverheadFactor { + v *= SegOverheadFactor if v < rs.Min { v = rs.Min } - if v > rs.Max { - v = rs.Max - } + } else { + v = math.MaxInt32 } - mask := uint32(notifyReceiveWindowChanged) - e.LockUser() e.rcvListMu.Lock() @@ -1660,14 +1698,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { v = 1 << scale } - // Make sure 2*size doesn't overflow. - if v > math.MaxInt32/2 { - v = math.MaxInt32 / 2 - } - - availBefore := e.receiveBufferAvailableLocked() + availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvBufSize = v - availAfter := e.receiveBufferAvailableLocked() + availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) e.rcvAutoParams.disabled = true @@ -1675,24 +1708,31 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // syndrome prevetion, when our available space grows above aMSS // or half receive buffer, whichever smaller. if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - mask |= notifyNonZeroReceiveWindow + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } e.rcvListMu.Unlock() e.UnlockUser() - e.notifyProtocolGoroutine(mask) case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. var ss tcpip.TCPSendBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err)) + } + + if v > ss.Max { + v = ss.Max + } + + if v < math.MaxInt32/SegOverheadFactor { + v *= SegOverheadFactor if v < ss.Min { v = ss.Min } - if v > ss.Max { - v = ss.Max - } + } else { + v = math.MaxInt32 } e.sndBufMu.Lock() @@ -2699,13 +2739,8 @@ func (e *endpoint) updateSndBufferUsage(v int) { func (e *endpoint) readyToRead(s *segment) { e.rcvListMu.Lock() if s != nil { + e.rcvBufUsed += s.payloadSize() s.incRef() - e.rcvBufUsed += s.data.Size() - // Increase counter if the receive window falls down below MSS - // or half receive buffer size, whichever smaller. - if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { - e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() - } e.rcvList.PushBack(s) } else { e.rcvClosed = true @@ -2720,15 +2755,17 @@ func (e *endpoint) readyToRead(s *segment) { func (e *endpoint) receiveBufferAvailableLocked() int { // We may use more bytes than the buffer size when the receive buffer // shrinks. - if e.rcvBufUsed >= e.rcvBufSize { + memUsed := e.receiveMemUsed() + if memUsed >= e.rcvBufSize { return 0 } - return e.rcvBufSize - e.rcvBufUsed + return e.rcvBufSize - memUsed } // receiveBufferAvailable calculates how many bytes are still available in the -// receive buffer. +// receive buffer based on the actual memory used by all segments held in +// receive buffer/pending and segment queue. func (e *endpoint) receiveBufferAvailable() int { e.rcvListMu.Lock() available := e.receiveBufferAvailableLocked() @@ -2736,14 +2773,35 @@ func (e *endpoint) receiveBufferAvailable() int { return available } +// receiveBufferUsed returns the amount of in-use receive buffer. +func (e *endpoint) receiveBufferUsed() int { + e.rcvListMu.Lock() + used := e.rcvBufUsed + e.rcvListMu.Unlock() + return used +} + +// receiveBufferSize returns the current size of the receive buffer. func (e *endpoint) receiveBufferSize() int { e.rcvListMu.Lock() size := e.rcvBufSize e.rcvListMu.Unlock() - return size } +// receiveMemUsed returns the total memory in use by segments held by this +// endpoint. +func (e *endpoint) receiveMemUsed() int { + return int(atomic.LoadInt32(&e.rcvMemUsed)) +} + +// updateReceiveMemUsed adds the provided delta to e.rcvMemUsed. +func (e *endpoint) updateReceiveMemUsed(delta int) { + atomic.AddInt32(&e.rcvMemUsed, int32(delta)) +} + +// maxReceiveBufferSize returns the stack wide maximum receive buffer size for +// an endpoint. func (e *endpoint) maxReceiveBufferSize() int { var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { @@ -2894,7 +2952,6 @@ func (e *endpoint) completeState() stack.TCPEndpointState { RcvAcc: e.rcv.rcvAcc, RcvWndScale: e.rcv.rcvWndScale, PendingBufUsed: e.rcv.pendingBufUsed, - PendingBufSize: e.rcv.pendingBufSize, } // Copy sender state. diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 41d0050f3..b25431467 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -44,7 +44,7 @@ func (e *endpoint) drainSegmentLocked() { // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets. - e.segmentQueue.setLimit(0) + e.segmentQueue.freeze() e.mu.Lock() defer e.mu.Unlock() @@ -178,7 +178,7 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.segmentQueue.thaw() epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 74a17af79..371067048 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -201,21 +201,20 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { s := newSegment(r, id, pkt) defer s.decRef() if !s.parse() || !s.csumValid { - return false + return stack.UnknownDestinationPacketMalformed } - // There's nothing to do if this is already a reset packet. - if s.flagIsSet(header.TCPFlagRst) { - return true + if !s.flagIsSet(header.TCPFlagRst) { + replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) } - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) - return true + return stack.UnknownDestinationPacketHandled } // replyWithReset replies to the given segment with a reset segment. diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index cfd43b5e3..4aafb4d22 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -47,22 +47,24 @@ type receiver struct { closed bool + // pendingRcvdSegments is bounded by the receive buffer size of the + // endpoint. pendingRcvdSegments segmentHeap - pendingBufUsed seqnum.Size - pendingBufSize seqnum.Size + // pendingBufUsed tracks the total number of bytes (including segment + // overhead) currently queued in pendingRcvdSegments. + pendingBufUsed int // Time when the last ack was received. lastRcvdAckTime time.Time `state:".(unixTime)"` } -func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { +func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { return &receiver{ ep: ep, rcvNxt: irs + 1, rcvAcc: irs.Add(rcvWnd + 1), rcvWnd: rcvWnd, rcvWndScale: rcvWndScale, - pendingBufSize: pendingBufSize, lastRcvdAckTime: time.Now(), } } @@ -85,15 +87,23 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // getSendParams returns the parameters needed by the sender when building // segments to send. func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { - // Calculate the window size based on the available buffer space. - receiveBufferAvailable := r.ep.receiveBufferAvailable() - acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable)) - if r.rcvAcc.LessThan(acc) { - r.rcvAcc = acc + avail := wndFromSpace(r.ep.receiveBufferAvailable()) + acc := r.rcvNxt.Add(seqnum.Size(avail)) + newWnd := r.rcvNxt.Size(acc) + curWnd := r.rcvNxt.Size(r.rcvAcc) + + // Update rcvAcc only if new window is > previously advertised window. We + // should never shrink the acceptable sequence space once it has been + // advertised the peer. If we shrink the acceptable sequence space then we + // would end up dropping bytes that might already be in flight. + if newWnd > curWnd { + r.rcvAcc = r.rcvNxt.Add(newWnd) + } else { + newWnd = curWnd } // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. - r.rcvWnd = r.rcvNxt.Size(r.rcvAcc) + r.rcvWnd = newWnd return r.rcvNxt, r.rcvWnd >> r.rcvWndScale } @@ -195,7 +205,9 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum } for i := first; i < len(r.pendingRcvdSegments); i++ { + r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() r.pendingRcvdSegments[i].decRef() + // Note that slice truncation does not allow garbage collection of // truncated items, thus truncated items must be set to nil to avoid // memory leaks. @@ -384,10 +396,16 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { - // We only store the segment if it's within our buffer - // size limit. - if r.pendingBufUsed < r.pendingBufSize { - r.pendingBufUsed += seqnum.Size(s.segMemSize()) + // We only store the segment if it's within our buffer size limit. + // + // Only use 75% of the receive buffer queue for out-of-order + // segments. This ensures that we always leave some space for the inorder + // segments to arrive allowing pending segments to be processed and + // delivered to the user. + if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 { + r.ep.rcvListMu.Lock() + r.pendingBufUsed += s.segMemSize() + r.ep.rcvListMu.Unlock() s.incRef() heap.Push(&r.pendingRcvdSegments, s) UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) @@ -421,7 +439,9 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { } heap.Pop(&r.pendingRcvdSegments) - r.pendingBufUsed -= seqnum.Size(s.segMemSize()) + r.ep.rcvListMu.Lock() + r.pendingBufUsed -= s.segMemSize() + r.ep.rcvListMu.Unlock() s.decRef() } return false, nil diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 94307d31a..13acaf753 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -15,6 +15,7 @@ package tcp import ( + "fmt" "sync/atomic" "time" @@ -24,6 +25,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// queueFlags are used to indicate which queue of an endpoint a particular segment +// belongs to. This is used to track memory accounting correctly. +type queueFlags uint8 + +const ( + recvQ queueFlags = 1 << iota + sendQ +) + // segment represents a TCP segment. It holds the payload and parsed TCP segment // information, and can be added to intrusive lists. // segment is mostly immutable, the only field allowed to change is viewToDeliver. @@ -32,6 +42,8 @@ import ( type segment struct { segmentEntry refCnt int32 + ep *endpoint + qFlags queueFlags id stack.TransportEndpointID `state:"manual"` route stack.Route `state:"manual"` data buffer.VectorisedView `state:".(buffer.VectorisedView)"` @@ -100,6 +112,8 @@ func (s *segment) clone() *segment { rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, xmitCount: s.xmitCount, + ep: s.ep, + qFlags: s.qFlags, } t.data = s.data.Clone(t.views[:]) return t @@ -115,8 +129,34 @@ func (s *segment) flagsAreSet(flags uint8) bool { return s.flags&flags == flags } +// setOwner sets the owning endpoint for this segment. Its required +// to be called to ensure memory accounting for receive/send buffer +// queues is done properly. +func (s *segment) setOwner(ep *endpoint, qFlags queueFlags) { + switch qFlags { + case recvQ: + ep.updateReceiveMemUsed(s.segMemSize()) + case sendQ: + // no memory account for sendQ yet. + default: + panic(fmt.Sprintf("unexpected queue flag %b", qFlags)) + } + s.ep = ep + s.qFlags = qFlags +} + func (s *segment) decRef() { if atomic.AddInt32(&s.refCnt, -1) == 0 { + if s.ep != nil { + switch s.qFlags { + case recvQ: + s.ep.updateReceiveMemUsed(-s.segMemSize()) + case sendQ: + // no memory accounting for sendQ yet. + default: + panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags)) + } + } s.route.Release() } } @@ -138,6 +178,11 @@ func (s *segment) logicalLen() seqnum.Size { return l } +// payloadSize is the size of s.data. +func (s *segment) payloadSize() int { + return s.data.Size() +} + // segMemSize is the amount of memory used to hold the segment data and // the associated metadata. func (s *segment) segMemSize() int { diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index 48a257137..54545a1b1 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -22,16 +22,16 @@ import ( // // +stateify savable type segmentQueue struct { - mu sync.Mutex `state:"nosave"` - list segmentList `state:"wait"` - limit int - used int + mu sync.Mutex `state:"nosave"` + list segmentList `state:"wait"` + ep *endpoint + frozen bool } // emptyLocked determines if the queue is empty. // Preconditions: q.mu must be held. func (q *segmentQueue) emptyLocked() bool { - return q.used == 0 + return q.list.Empty() } // empty determines if the queue is empty. @@ -43,14 +43,6 @@ func (q *segmentQueue) empty() bool { return r } -// setLimit updates the limit. No segments are immediately dropped in case the -// queue becomes full due to the new limit. -func (q *segmentQueue) setLimit(limit int) { - q.mu.Lock() - q.limit = limit - q.mu.Unlock() -} - // enqueue adds the given segment to the queue. // // Returns true when the segment is successfully added to the queue, in which @@ -58,15 +50,23 @@ func (q *segmentQueue) setLimit(limit int) { // false if the queue is full, in which case ownership is retained by the // caller. func (q *segmentQueue) enqueue(s *segment) bool { + // q.ep.receiveBufferParams() must be called without holding q.mu to + // avoid lock order inversion. + bufSz := q.ep.receiveBufferSize() + used := q.ep.receiveMemUsed() q.mu.Lock() - r := q.used < q.limit - if r { + // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue + // is currently full). + allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen + + if allow { q.list.PushBack(s) - q.used++ + // Set the owner now that the endpoint owns the segment. + s.setOwner(q.ep, recvQ) } q.mu.Unlock() - return r + return allow } // dequeue removes and returns the next segment from queue, if one exists. @@ -77,9 +77,25 @@ func (q *segmentQueue) dequeue() *segment { s := q.list.Front() if s != nil { q.list.Remove(s) - q.used-- } q.mu.Unlock() return s } + +// freeze prevents any more segments from being added to the queue. i.e all +// future segmentQueue.enqueue will return false and not add the segment to the +// queue till the queue is unfroze with a corresponding segmentQueue.thaw call. +func (q *segmentQueue) freeze() { + q.mu.Lock() + q.frozen = true + q.mu.Unlock() +} + +// thaw unfreezes a previously frozen queue using segmentQueue.freeze() and +// allows new segments to be queued again. +func (q *segmentQueue) thaw() { + q.mu.Lock() + q.frozen = false + q.mu.Unlock() +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index b1e5f1b24..8326736dc 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -240,6 +241,38 @@ func TestTCPResetsSentIncrement(t *testing.T) { } } +// TestTCPResetsSentNoICMP confirms that we don't get an ICMP +// DstUnreachable packet when we try send a packet which is not part +// of an active session. +func TestTCPResetsSentNoICMP(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + stats := c.Stack().Stats() + + // Send a SYN request for a closed port. This should elicit an RST + // but NOT an ICMPv4 DstUnreachable packet. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive whatever comes back. + b := c.GetPacket() + ipHdr := header.IPv4(b) + if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want { + t.Errorf("unexpected protocol, got = %d, want = %d", got, want) + } + + // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. + sent := stats.ICMP.V4PacketsSent + if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { + t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) + } +} + // TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates // a RST if an ACK is received on the listening socket for which there is no // active handshake in progress and we are not using SYN cookies. @@ -317,8 +350,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ SrcPort: context.TestPort, @@ -348,8 +381,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) } @@ -447,8 +480,8 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -489,8 +522,8 @@ func TestConnectResetAfterClose(t *testing.T) { // RST is always generated with sndNxt which if the FIN // has been sent will be 1 higher than the sequence number // of the FIN itself. - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -529,8 +562,8 @@ func TestCurrentConnectedIncrement(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -565,8 +598,8 @@ func TestCurrentConnectedIncrement(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -612,8 +645,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -633,8 +666,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -693,8 +726,8 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(0), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -745,8 +778,8 @@ func TestSimpleReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -998,7 +1031,7 @@ func TestSendRstOnListenerRxSynAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } func TestSendRstOnListenerRxSynAckV6(t *testing.T) { @@ -1026,7 +1059,7 @@ func TestSendRstOnListenerRxSynAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } // TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, @@ -1063,8 +1096,8 @@ func TestTCPAckBeforeAcceptV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } // TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, @@ -1101,8 +1134,8 @@ func TestTCPAckBeforeAcceptV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestSendRstOnListenerRxAckV4(t *testing.T) { @@ -1130,7 +1163,7 @@ func TestSendRstOnListenerRxAckV4(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } func TestSendRstOnListenerRxAckV6(t *testing.T) { @@ -1158,7 +1191,7 @@ func TestSendRstOnListenerRxAckV6(t *testing.T) { checker.IPv6(t, c.GetV6Packet(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst), - checker.SeqNum(200))) + checker.TCPSeqNum(200))) } // TestListenShutdown tests for the listening endpoint replying with RST @@ -1274,8 +1307,8 @@ func TestTOSV4(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), // Acknum is initial sequence number + 1 + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), // Acknum is initial sequence number + 1 checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), checker.TOS(tos, 0), @@ -1323,8 +1356,8 @@ func TestTrafficClassV6(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), checker.TOS(tos, 0), @@ -1514,8 +1547,8 @@ func TestOutOfOrderReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1565,8 +1598,8 @@ func TestOutOfOrderReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1576,8 +1609,8 @@ func TestOutOfOrderFlood(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Create a new connection with initial window size of 10. - c.CreateConnected(789, 30000, 10) + rcvBufSz := math.MaxUint16 + c.CreateConnected(789, 30000, rcvBufSz) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) @@ -1598,8 +1631,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1619,8 +1652,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1639,8 +1672,8 @@ func TestOutOfOrderFlood(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(793), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1681,8 +1714,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1696,7 +1729,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -1750,8 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -1764,7 +1797,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), )) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { @@ -1783,7 +1816,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // RST is always generated with sndNxt which if the FIN // has been sent will be 1 higher than the sequence // number of the FIN itself. - checker.SeqNum(uint32(c.IRS)+2), + checker.TCPSeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -1829,7 +1862,8 @@ func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, 10) + const rcvBufSz = 10 + c.CreateConnected(789, 30000, rcvBufSz) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1840,8 +1874,13 @@ func TestFullWindowReceive(t *testing.T) { t.Fatalf("Read failed: %s", err) } - // Fill up the window. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies + // the provided buffer value by tcp.SegOverheadFactor to calculate the actual + // receive buffer size. + data := make([]byte, tcp.SegOverheadFactor*rcvBufSz) + for i := range data { + data[i] = byte(i % 255) + } c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -1862,10 +1901,10 @@ func TestFullWindowReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), + checker.TCPWindow(0), ), ) @@ -1888,10 +1927,10 @@ func TestFullWindowReceive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+len(data))), checker.TCPFlags(header.TCPFlagAck), - checker.Window(10), + checker.TCPWindow(10), ), ) } @@ -1900,12 +1939,15 @@ func TestNoWindowShrinking(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Start off with a window size of 10, then shrink it to 5. - c.CreateConnected(789, 30000, 10) - - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) - } + // Start off with a certain receive buffer then cut it in half and verify that + // the right edge of the window does not shrink. + // NOTE: Netstack doubles the value specified here. + rcvBufSize := 65536 + iss := seqnum.Value(789) + // Enable window scaling with a scale of zero from our end. + c.CreateConnectedWithRawOptions(iss, 30000, rcvBufSize, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -1914,14 +1956,15 @@ func TestNoWindowShrinking(t *testing.T) { if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) } - - // Send 3 bytes, check that the peer acknowledges them. - data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - c.SendPacket(data[:3], &context.Headers{ + // Send a 1 byte payload so that we can record the current receive window. + // Send a payload of half the size of rcvBufSize. + seqNum := iss.Add(1) + payload := []byte{1} + c.SendPacket(payload, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 790, + SeqNum: seqNum, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -1933,46 +1976,93 @@ func TestNoWindowShrinking(t *testing.T) { t.Fatalf("Timed out waiting for data to arrive") } - // Check that data is acknowledged, and that window doesn't go to zero - // just yet because it was previously set to 10. It must go to 7 now. - checker.IPv4(t, c.GetPacket(), + // Read the 1 byte payload we just sent. + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + if got, want := payload, v; !bytes.Equal(got, want) { + t.Fatalf("got data: %v, want: %v", got, want) + } + + seqNum = seqNum.Add(1) + // Verify that the ACK does not shrink the window. + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(793), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(7), ), ) + // Stash the initial window. + initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale + initialLastAcceptableSeq := seqNum.Add(seqnum.Size(initialWnd)) + // Now shrink the receive buffer to half its original size. + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) + } - // Send 7 more bytes, check that the window fills up. - c.SendPacket(data[3:], &context.Headers{ + data := generateRandomPayload(t, rcvBufSize) + // Send a payload of half the size of rcvBufSize. + c.SendPacket(data[:rcvBufSize/2], &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: 793, + SeqNum: seqNum, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) + seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") + // Verify that the ACK does not shrink the window. + pkt = c.GetPacket() + checker.IPv4(t, pkt, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale + newLastAcceptableSeq := seqNum.Add(seqnum.Size(newWnd)) + if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { + t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) } + // Send another payload of half the size of rcvBufSize. This should fill up the + // socket receive buffer and we should see a zero window. + c.SendPacket(data[rcvBufSize/2:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqNum, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + seqNum = seqNum.Add(seqnum.Size(rcvBufSize / 2)) + checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(0), + checker.TCPWindow(0), ), ) + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + // Receive data and check it. - read := make([]byte, 0, 10) + read := make([]byte, 0, rcvBufSize) for len(read) < len(data) { v, _, err := c.EP.Read(nil) if err != nil { @@ -1986,15 +2076,15 @@ func TestNoWindowShrinking(t *testing.T) { t.Fatalf("got data = %v, want = %v", read, data) } - // Check that we get an ACK for the newly non-zero window, which is the - // new size. + // Check that we get an ACK for the newly non-zero window, which is the new + // receive buffer size we set after the connection was established. checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(seqNum)), checker.TCPFlags(header.TCPFlagAck), - checker.Window(5), + checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), ), ) } @@ -2019,8 +2109,8 @@ func TestSimpleSend(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2061,8 +2151,8 @@ func TestZeroWindowSend(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2083,8 +2173,8 @@ func TestZeroWindowSend(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2123,16 +2213,16 @@ func TestScaledWindowConnect(t *testing.T) { t.Fatalf("Write failed: %s", err) } - // Check that data is received, and that advertised window is 0xbfff, + // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. b := c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2162,9 +2252,9 @@ func TestNonScaledWindowConnect(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2198,7 +2288,8 @@ func TestScaledWindowAccept(t *testing.T) { } // Do 3-way handshake. - c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) + // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 + c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}) // Try to accept the connection. we, ch := waiter.NewChannelEntry(nil) @@ -2228,16 +2319,16 @@ func TestScaledWindowAccept(t *testing.T) { t.Fatalf("Write failed: %s", err) } - // Check that data is received, and that advertised window is 0xbfff, + // Check that data is received, and that advertised window is 0x5fff, // that is, that it is scaled. b := c.GetPacket() checker.IPv4(t, b, checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xbfff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0x5fff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2309,9 +2400,9 @@ func TestNonScaledWindowAccept(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(0xffff), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(0xffff), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2324,18 +2415,19 @@ func TestZeroScaledWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - // Set the window size such that a window scale of 4 will be used. - const wnd = 65535 * 10 - const ws = uint32(4) - c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ + // Set the buffer size such that a window scale of 5 will be used. + const bufSz = 65535 * 10 + const ws = uint32(5) + c.CreateConnectedWithRawOptions(789, 30000, bufSz, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) // Write chunks of 50000 bytes. - remain := wnd + remain := 0 sent := 0 data := make([]byte, 50000) - for remain > len(data) { + // Keep writing till the window drops below len(data). + for { c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -2345,21 +2437,25 @@ func TestZeroScaledWindowReceive(t *testing.T) { RcvWnd: 30000, }) sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(remain>>ws)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Don't reduce window to zero here. + if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) { + remain = wnd << ws + break + } } // Make the window non-zero, but the scaled window zero. - if remain >= 16 { + for remain >= 16 { data = data[:remain-15] c.SendPacket(data, &context.Headers{ SrcPort: context.TestPort, @@ -2370,22 +2466,35 @@ func TestZeroScaledWindowReceive(t *testing.T) { RcvWnd: 30000, }) sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(0), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Since the receive buffer is split between window advertisement and + // application data buffer the window does not always reflect the space + // available and actual space available can be a bit more than what is + // advertised in the window. + wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) + if wnd == 0 { + break + } + remain = wnd << ws } - // Read at least 1MSS of data. An ack should be sent in response to that. + // Read at least 2MSS of data. An ack should be sent in response to that. + // Since buffer space is now split in half between window and application + // data we need to read more than 1 MSS(65536) of data for a non-zero window + // update to be sent. For 1MSS worth of window to be available we need to + // read at least 128KB. Since our segments above were 50KB each it means + // we need to read at 3 packets. sz := 0 - for sz < defaultMTU { + for sz < defaultMTU*2 { v, _, err := c.EP.Read(nil) if err != nil { t.Fatalf("Read failed: %s", err) @@ -2397,9 +2506,9 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(sz>>ws)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -2466,8 +2575,8 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize+1), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+uint32(i)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2489,8 +2598,8 @@ func TestSegmentMerging(t *testing.T) { checker.PayloadLen(len(allData)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+11), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+11), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2537,8 +2646,8 @@ func TestDelay(t *testing.T) { checker.PayloadLen(len(want)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2584,8 +2693,8 @@ func TestUndelay(t *testing.T) { checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2607,8 +2716,8 @@ func TestUndelay(t *testing.T) { checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2669,8 +2778,8 @@ func TestMSSNotDelayed(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(seq)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(seq)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2721,8 +2830,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -2964,7 +3073,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - const wndScale = 2 + const wndScale = 3 if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } @@ -2999,7 +3108,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagSyn), checker.SrcPort(tcpHdr.SourcePort()), - checker.SeqNum(tcpHdr.SequenceNumber()), + checker.TCPSeqNum(tcpHdr.SequenceNumber()), checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), ) @@ -3020,8 +3129,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), ), ) @@ -3314,8 +3423,8 @@ func TestFinImmediately(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3335,8 +3444,8 @@ func TestFinImmediately(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3357,8 +3466,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3368,8 +3477,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3389,8 +3498,8 @@ func TestFinRetransmit(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(791), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3413,8 +3522,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3438,8 +3547,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3460,8 +3569,8 @@ func TestFinWithNoPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3488,8 +3597,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3507,8 +3616,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3527,8 +3636,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3548,8 +3657,8 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3572,8 +3681,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3597,8 +3706,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3613,8 +3722,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3634,8 +3743,8 @@ func TestFinWithPendingData(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3659,8 +3768,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3680,8 +3789,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3695,8 +3804,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3711,8 +3820,8 @@ func TestFinWithPartialAck(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(791), + checker.TCPSeqNum(next), + checker.TCPAckNum(791), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -3803,8 +3912,8 @@ func scaledSendWindow(t *testing.T, scale uint8) { checker.PayloadLen((1<<scale)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -3942,7 +4051,7 @@ func TestReceivedSegmentQueuing(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3993,8 +4102,8 @@ func TestReadAfterClosedState(t *testing.T) { checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), ), ) @@ -4018,8 +4127,8 @@ func TestReadAfterClosedState(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+2), - checker.AckNum(uint32(791+len(data))), + checker.TCPSeqNum(uint32(c.IRS)+2), + checker.TCPAckNum(uint32(791+len(data))), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4280,14 +4389,14 @@ func TestMinMaxBufferSizes(t *testing.T) { } } - // Set values below the min. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { + // Set values below the min/2. + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) } checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil { t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) } @@ -4298,13 +4407,15 @@ func TestMinMaxBufferSizes(t *testing.T) { t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) + // Values above max are capped at max and then doubled. + checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) } - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) + // Values above max are capped at max and then doubled. + checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) } func TestBindToDeviceOption(t *testing.T) { @@ -4646,8 +4757,8 @@ func TestPathMTUDiscovery(t *testing.T) { checker.PayloadLen(size+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(seqNum), - checker.AckNum(790), + checker.TCPSeqNum(seqNum), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -4898,8 +5009,8 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4932,8 +5043,8 @@ func TestKeepalive(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -4944,8 +5055,8 @@ func TestKeepalive(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), ), ) @@ -4970,8 +5081,8 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next-1)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(next-1)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -4997,8 +5108,8 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(next)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -5038,7 +5149,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki checker.SrcPort(context.StackPort), checker.DstPort(srcPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } if synCookieInUse { @@ -5082,7 +5193,7 @@ func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCoo checker.SrcPort(context.StackPort), checker.DstPort(srcPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } if synCookieInUse { @@ -5316,7 +5427,7 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) }) } } @@ -5416,7 +5527,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) }) } } @@ -5464,7 +5575,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) @@ -5642,7 +5753,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(irs) + 1), + checker.TCPAckNum(uint32(irs) + 1), } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) @@ -5663,8 +5774,8 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.AckNum(uint32(irs) + 1), - checker.SeqNum(uint32(iss + 1)), + checker.TCPAckNum(uint32(irs) + 1), + checker.TCPSeqNum(uint32(iss + 1)), } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) @@ -5962,16 +6073,14 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { time.Sleep(latency) rawEP.SendPacketWithTS([]byte{1}, tsVal) - // Verify that the ACK has the expected window. - wantRcvWnd := receiveBufferSize - wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale)) - rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1)) + pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) + rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() time.Sleep(25 * time.Millisecond) // Allocate a large enough payload for the test. - b := make([]byte, int(receiveBufferSize)*2) - offset := 0 - payloadSize := receiveBufferSize - 1 + payloadSize := receiveBufferSize * 2 + b := make([]byte, int(payloadSize)) + worker := (c.EP).(interface { StopWork() ResumeWork() @@ -5980,11 +6089,15 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Stop the worker goroutine. worker.StopWork() - start := offset - end := offset + payloadSize + start := 0 + end := payloadSize / 2 packetsSent := 0 for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) + packetEnd := start + mss + if start+mss > end { + packetEnd = end + } + rawEP.SendPacketWithTS(b[start:packetEnd], tsVal) packetsSent++ } @@ -5992,29 +6105,20 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // are waiting to be read. worker.ResumeWork() - // Since we read no bytes the window should goto zero till the - // application reads some of the data. - // Discard all intermediate acks except the last one. - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() - } + // Since we sent almost the full receive buffer worth of data (some may have + // been dropped due to segment overheads), we should get a zero window back. + pkt = c.GetPacket() + tcpHdr := header.TCP(header.IPv4(pkt).Payload()) + gotRcvWnd := tcpHdr.WindowSize() + wantAckNum := tcpHdr.AckNumber() + if got, want := int(gotRcvWnd), 0; got != want { + t.Fatalf("got rcvWnd: %d, want: %d", got, want) } - rawEP.VerifyACKRcvWnd(0) time.Sleep(25 * time.Millisecond) - // Verify that sending more data when window is closed is dropped and - // not acked. + // Verify that sending more data when receiveBuffer is exhausted. rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - // Verify that the stack sends us back an ACK with the sequence number - // of the last packet sent indicating it was dropped. - p := c.GetPacket() - checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), - checker.Window(0), - )) - // Now read all the data from the endpoint and verify that advertised // window increases to the full available buffer size. for { @@ -6027,23 +6131,26 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Verify that we receive a non-zero window update ACK. When running // under thread santizer this test can end up sending more than 1 // ack, 1 for the non-zero window - p = c.GetPacket() + p := c.GetPacket() checker.IPv4(t, p, checker.TCP( - checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)), + checker.TCPAckNum(uint32(wantAckNum)), func(t *testing.T, h header.Transport) { tcp, ok := h.(header.TCP) if !ok { return } - if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) { - t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w) + // We use 10% here as the error margin upwards as the initial window we + // got was afer 1 segment was already in the receive buffer queue. + tolerance := 1.1 + if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) { + t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance)) } }, )) } -// This test verifies that the auto tuning does not grow the receive buffer if -// the application is not reading the data actively. +// This test verifies that the advertised window is auto-tuned up as the +// application is reading the data that is being received. func TestReceiveBufferAutoTuning(t *testing.T) { const mtu = 1500 const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize @@ -6053,9 +6160,6 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Enable Auto-tuning. stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 { @@ -6077,8 +6181,10 @@ func TestReceiveBufferAutoTuning(t *testing.T) { c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - - wantRcvWnd := receiveBufferSize + tsVal := uint32(rawEP.TSVal) + rawEP.SendPacketWithTS([]byte{1}, tsVal) + pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) + curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale scaleRcvWnd := func(rcvWnd int) uint16 { return uint16(rcvWnd >> uint16(c.WindowScale)) } @@ -6095,14 +6201,8 @@ func TestReceiveBufferAutoTuning(t *testing.T) { StopWork() ResumeWork() }) - tsVal := rawEP.TSVal - // We are going to do our own computation of what the moderated receive - // buffer should be based on sent/copied data per RTT and verify that - // the advertised window by the stack matches our calculations. - prevCopied := 0 - done := false latency := 1 * time.Millisecond - for i := 0; !done; i++ { + for i := 0; i < 5; i++ { tsVal++ // Stop the worker goroutine. @@ -6124,15 +6224,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Give 1ms for the worker to process the packets. time.Sleep(1 * time.Millisecond) - // Verify that the advertised window on the ACK is reduced by - // the total bytes sent. - expectedWnd := wantRcvWnd - totalSent - if packetsSent > 100 { - for i := 0; i < (packetsSent / 100); i++ { - _ = c.GetPacket() + lastACK := c.GetPacket() + // Discard any intermediate ACKs and only check the last ACK we get in a + // short time period of few ms. + for { + time.Sleep(1 * time.Millisecond) + pkt := c.GetPacketNonBlocking() + if pkt == nil { + break } + lastACK = pkt + } + if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want { + t.Fatalf("advertised window got: %d, want <= %d", got, want) } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd)) // Now read all the data from the endpoint and invoke the // moderation API to allow for receive buffer auto-tuning @@ -6157,35 +6262,20 @@ func TestReceiveBufferAutoTuning(t *testing.T) { rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ - if i == 0 { // In the first iteration the receiver based RTT is not // yet known as a result the moderation code should not // increase the advertised window. - rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd)) - prevCopied = totalCopied + rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) } else { - rttCopied := totalCopied - if i == 1 { - // The moderation code accumulates copied bytes till - // RTT is established. So add in the bytes sent in - // the first iteration to the total bytes for this - // RTT. - rttCopied += prevCopied - // Now reset it to the initial value used by the - // auto tuning logic. - prevCopied = tcp.InitialCwnd * mss * 2 - } - newWnd := rttCopied<<1 + 16*mss - grow := (newWnd * (rttCopied - prevCopied)) / prevCopied - newWnd += (grow << 1) - if newWnd > maxReceiveBufferSize { - newWnd = maxReceiveBufferSize - done = true + pkt := c.GetPacket() + curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale + // If thew new current window is close maxReceiveBufferSize then terminate + // the loop. This can happen before all iterations are done due to timing + // differences when running the test. + if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 { + break } - rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd)) - wantRcvWnd = newWnd - prevCopied = rttCopied // Increase the latency after first two iterations to // establish a low RTT value in the receiver since it // only tracks the lowest value. This ensures that when @@ -6198,6 +6288,12 @@ func TestReceiveBufferAutoTuning(t *testing.T) { offset += payloadSize payloadSize *= 2 } + // Check that at the end of our iterations the receive window grew close to the maximum + // permissible size of maxReceiveBufferSize/2 + if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want { + t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want) + } + } func TestDelayEnabled(t *testing.T) { @@ -6349,8 +6445,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6367,8 +6463,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Now send a RST and this should be ignored and not @@ -6396,8 +6492,8 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) } @@ -6468,8 +6564,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6486,8 +6582,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Out of order ACK should generate an immediate ACK in @@ -6503,8 +6599,8 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) } @@ -6575,8 +6671,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6593,8 +6689,8 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Send a SYN request w/ sequence number lower than @@ -6732,8 +6828,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) finHeaders := &context.Headers{ @@ -6750,8 +6846,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) time.Sleep(2 * time.Second) @@ -6765,8 +6861,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+2)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+2)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Sleep for 4 seconds so at this point we are 1 second past the @@ -6794,8 +6890,8 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(ackHeaders.AckNum)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { @@ -6894,8 +6990,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(iss)+2), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(iss)+2), checker.TCPFlags(header.TCPFlagAck))) // Now write a few bytes and then close the endpoint. @@ -6913,8 +7009,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -6928,8 +7024,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), - checker.AckNum(uint32(iss+2)), + checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))), + checker.TCPAckNum(uint32(iss+2)), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) // First send a partial ACK. @@ -6974,8 +7070,8 @@ func TestTCPCloseWithData(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), - checker.SeqNum(uint32(ackHeaders.AckNum)), - checker.AckNum(0), + checker.TCPSeqNum(uint32(ackHeaders.AckNum)), + checker.TCPAckNum(0), checker.TCPFlags(header.TCPFlagRst))) } @@ -7011,8 +7107,8 @@ func TestTCPUserTimeout(t *testing.T) { checker.PayloadLen(len(view)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(next), - checker.AckNum(790), + checker.TCPSeqNum(next), + checker.TCPAckNum(790), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -7046,8 +7142,8 @@ func TestTCPUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(next)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(next)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -7108,8 +7204,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)), - checker.AckNum(uint32(790)), + checker.TCPSeqNum(uint32(c.IRS)), + checker.TCPAckNum(uint32(790)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7134,8 +7230,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS+1)), - checker.AckNum(uint32(0)), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), ) @@ -7151,9 +7247,9 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { } } -func TestIncreaseWindowOnReceive(t *testing.T) { +func TestIncreaseWindowOnRead(t *testing.T) { // This test ensures that the endpoint sends an ack, - // after recv() when the window grows to more than 1 MSS. + // after read() when the window grows by more than 1 MSS. c := context.New(t, defaultMTU) defer c.Cleanup() @@ -7162,10 +7258,9 @@ func TestIncreaseWindowOnReceive(t *testing.T) { // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. - remain := rcvBuf + remain := rcvBuf * 2 sent := 0 data := make([]byte, defaultMTU/2) - lastWnd := uint16(0) for remain > len(data) { c.SendPacket(data, &context.Headers{ @@ -7178,46 +7273,43 @@ func TestIncreaseWindowOnReceive(t *testing.T) { }) sent += len(data) remain -= len(data) - - lastWnd = uint16(remain) - if remain > 0xffff { - lastWnd = 0xffff - } - checker.IPv4(t, c.GetPacket(), + pkt := c.GetPacket() + checker.IPv4(t, pkt, checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(lastWnd), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), checker.TCPFlags(header.TCPFlagAck), ), ) + // Break once the window drops below defaultMTU/2 + if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 { + break + } } - if lastWnd == 0xffff || lastWnd == 0 { - t.Fatalf("expected small, non-zero window: %d", lastWnd) - } - - // We now have < 1 MSS in the buffer space. Read the data! An - // ack should be sent in response to that. The window was not - // zero, but it grew to larger than MSS. - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %s", err) - } - - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("Read failed: %s", err) + // We now have < 1 MSS in the buffer space. Read at least > 2 MSS + // worth of data as receive buffer space + read := 0 + // defaultMTU is a good enough estimate for the MSS used for this + // connection. + for read < defaultMTU*2 { + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + read += len(v) } - // After reading two packets, we surely crossed MSS. See the ack: + // After reading > MSS worth of data, we surely crossed MSS. See the ack: checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(0xffff)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7234,10 +7326,9 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. - remain := rcvBuf + remain := rcvBuf * 2 sent := 0 data := make([]byte, defaultMTU/2) - lastWnd := uint16(0) for remain > len(data) { c.SendPacket(data, &context.Headers{ @@ -7251,38 +7342,29 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { sent += len(data) remain -= len(data) - lastWnd = uint16(remain) - if remain > 0xffff { - lastWnd = 0xffff - } checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(lastWnd), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindowLessThanEq(0xffff), checker.TCPFlags(header.TCPFlagAck), ), ) } - if lastWnd == 0xffff || lastWnd == 0 { - t.Fatalf("expected small, non-zero window: %d", lastWnd) - } - // Increasing the buffer from should generate an ACK, // since window grew from small value to larger equal MSS c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) - // After reading two packets, we surely crossed MSS. See the ack: checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(790+sent)), - checker.Window(uint16(0xffff)), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(790+sent)), + checker.TCPWindow(uint16(0xffff)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -7327,8 +7409,8 @@ func TestTCPDeferAccept(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) // Give a bit of time for the socket to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) @@ -7342,8 +7424,8 @@ func TestTCPDeferAccept(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestTCPDeferAcceptTimeout(t *testing.T) { @@ -7380,7 +7462,7 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.AckNum(uint32(irs)+1))) + checker.TCPAckNum(uint32(irs)+1))) // Send data. This should result in an acceptable endpoint. c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ @@ -7396,8 +7478,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) // Give sometime for the endpoint to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) @@ -7412,8 +7494,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SrcPort(context.StackPort), checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.SeqNum(uint32(iss+1)), - checker.AckNum(uint32(irs+5)))) + checker.TCPSeqNum(uint32(iss+1)), + checker.TCPAckNum(uint32(irs+5)))) } func TestResetDuringClose(t *testing.T) { @@ -7438,8 +7520,8 @@ func TestResetDuringClose(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(irs.Add(1))), - checker.AckNum(uint32(iss.Add(5))))) + checker.TCPSeqNum(uint32(irs.Add(1))), + checker.TCPAckNum(uint32(iss.Add(5))))) // Close in a separate goroutine so that we can trigger // a race with the RST we send below. This should not @@ -7520,3 +7602,14 @@ func TestSetStackTimeWaitReuse(t *testing.T) { } } } + +// generateRandomPayload generates a random byte slice of the specified length +// causing a fatal test failure if it is unable to do so. +func generateRandomPayload(t *testing.T, n int) []byte { + t.Helper() + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + t.Fatalf("rand.Read(buf) failed: %s", err) + } + return buf +} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 44593ed98..0f9ed06cd 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -159,9 +159,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS checker.PayloadLen(len(data)+header.TCPMinimumSize+12), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(wndSize), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), checker.TCPTimestampChecker(true, 0, tsVal+1), ), @@ -181,7 +181,8 @@ func TestTimeStampEnabledAccept(t *testing.T) { wndSize uint16 }{ {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. + // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that. + {false, 5, 0x4000}, } for _, tc := range testCases { timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) @@ -219,9 +220,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd checker.PayloadLen(len(data)+header.TCPMinimumSize), checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(790), - checker.Window(wndSize), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(790), + checker.TCPWindow(wndSize), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), checker.TCPTimestampChecker(false, 0, 0), ), @@ -237,7 +238,9 @@ func TestTimeStampDisabledAccept(t *testing.T) { wndSize uint16 }{ {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. + // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of + // that. + {false, 5, 0x4000}, } for _, tc := range testCases { timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 85e8c1c75..ebbae6e2f 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -145,6 +145,10 @@ type Context struct { // WindowScale is the expected window scale in SYN packets sent by // the stack. WindowScale uint8 + + // RcvdWindowScale is the actual window scale sent by the stack in + // SYN/SYN-ACK. + RcvdWindowScale uint8 } // New allocates and initializes a test context containing a new @@ -261,18 +265,17 @@ func (c *Context) CheckNoPacket(errMsg string) { c.CheckNoPacketTimeout(errMsg, 1*time.Second) } -// GetPacket reads a packet from the link layer endpoint and verifies +// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies // that it is an IPv4 packet with the expected source and destination -// addresses. It will fail with an error if no packet is received for -// 2 seconds. -func (c *Context) GetPacket() []byte { +// addresses. If no packet is received in the specified timeout it will return +// nil. +func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte { c.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { - c.t.Fatalf("Packet wasn't written out") return nil } @@ -280,6 +283,14 @@ func (c *Context) GetPacket() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } + // Just check that the stack set the transport protocol number for outbound + // TCP messages. + // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part + // of the headerinfo. + if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { + c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) + } + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) b := vv.ToView() @@ -291,6 +302,21 @@ func (c *Context) GetPacket() []byte { return b } +// GetPacket reads a packet from the link layer endpoint and verifies +// that it is an IPv4 packet with the expected source and destination +// addresses. +func (c *Context) GetPacket() []byte { + c.t.Helper() + + p := c.GetPacketWithTimeout(5 * time.Second) + if p == nil { + c.t.Fatalf("Packet wasn't written out") + return nil + } + + return p +} + // GetPacketNonBlocking reads a packet from the link layer endpoint // and verifies that it is an IPv4 packet with the expected source // and destination address. If no packet is available it will return @@ -307,6 +333,14 @@ func (c *Context) GetPacketNonBlocking() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } + // Just check that the stack set the transport protocol number for outbound + // TCP messages. + // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part + // of the headerinfo. + if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { + c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) + } + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) b := vv.ToView() @@ -470,8 +504,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op checker.PayloadLen(size+header.TCPMinimumSize+optlen), checker.TCP( checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -497,8 +531,8 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int checker.PayloadLen(size+header.TCPMinimumSize), checker.TCP( checker.DstPort(TestPort), - checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -636,6 +670,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) } tcpHdr := header.TCP(header.IPv4(b).Payload()) + synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */) c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) c.SendPacket(nil, &Headers{ @@ -653,8 +688,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) checker.TCP( checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), + checker.TCPSeqNum(uint32(c.IRS)+1), + checker.TCPAckNum(uint32(iss)+1), ), ) @@ -671,6 +706,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) } + c.RcvdWindowScale = uint8(synOpts.WS) c.Port = tcpHdr.SourcePort() } @@ -742,17 +778,18 @@ func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) } -// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided -// tsVal. -func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { +// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches +// the provided tsVal as well as returns the original packet. +func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte { + r.C.t.Helper() // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] ackPacket := r.C.GetPacket() checker.IPv4(r.C.t, ackPacket, checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), checker.TCPTimestampChecker(true, 0, tsVal), ), ) @@ -760,19 +797,28 @@ func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) opts := tcpSeg.ParsedOptions() r.RecentTS = opts.TSVal + return ackPacket +} + +// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided +// tsVal. +func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { + r.C.t.Helper() + _ = r.VerifyAndReturnACKWithTS(tsVal) } // VerifyACKRcvWnd verifies that the window advertised by the incoming ACK // matches the provided rcvWnd. func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) { + r.C.t.Helper() ackPacket := r.C.GetPacket() checker.IPv4(r.C.t, ackPacket, checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), - checker.Window(rcvWnd), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), + checker.TCPWindow(rcvWnd), ), ) } @@ -791,8 +837,8 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { checker.TCP( checker.DstPort(r.SrcPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(r.AckNum)), - checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPSeqNum(uint32(r.AckNum)), + checker.TCPAckNum(uint32(r.NextSeqNum)), checker.TCPSACKBlockChecker(sackBlocks), ), ) @@ -884,8 +930,8 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * tcpCheckers := []checker.TransportChecker{ checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS) + 1), - checker.AckNum(uint32(iss) + 1), + checker.TCPSeqNum(uint32(c.IRS) + 1), + checker.TCPAckNum(uint32(iss) + 1), } // Verify that tsEcr of ACK packet is wantOptions.TSVal if the @@ -920,7 +966,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // Mark in context that timestamp option is enabled for this endpoint. c.TimeStampEnabled = true - + c.RcvdWindowScale = uint8(synOptions.WS) return &RawEndpoint{ C: c, SrcPort: tcpSeg.DestinationPort(), @@ -1013,6 +1059,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP // value of the window scaling option to be sent in the SYN. If synOptions.WS > // 0 then we send the WindowScale option. func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + c.t.Helper() opts := make([]byte, header.TCPOptionsMaximumSize) offset := 0 offset += header.EncodeMSSOption(uint32(maxPayload), opts) @@ -1051,13 +1098,14 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // are present. b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) + rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */) c.IRS = seqnum.Value(tcp.SequenceNumber()) tcpCheckers := []checker.TransportChecker{ checker.SrcPort(StackPort), checker.DstPort(TestPort), checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.AckNum(uint32(iss) + 1), + checker.TCPAckNum(uint32(iss) + 1), checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), } @@ -1100,6 +1148,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // Send ACK. c.SendPacket(nil, ackHeaders) + c.RcvdWindowScale = uint8(rcvdSynOptions.WS) c.Port = StackPort return &RawEndpoint{ diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 518f636f0..086d0bdbc 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -996,6 +996,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // Initialize the UDP header. udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + pkt.TransportProtocolNumber = ProtocolNumber length := uint16(pkt.Size()) udp.Encode(&header.UDPFields{ diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 7d6b91a75..a1d0f49d9 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -80,126 +80,21 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { return h.SourcePort(), h.DestinationPort(), nil } -// HandleUnknownDestinationPacket handles packets targeted at this protocol but -// that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { +// HandleUnknownDestinationPacket handles packets that are targeted at this +// protocol but don't match any existing endpoint. +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { - // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true + return stack.UnknownDestinationPacketMalformed } if !verifyChecksum(r, hdr, pkt) { - // Checksum Error. r.Stack().Stats().UDP.ChecksumErrors.Increment() - return true + return stack.UnknownDestinationPacketMalformed } - // Only send ICMP error if the address is not a multicast/broadcast - // v4/v6 address or the source is not the unspecified address. - // - // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4 - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any { - return true - } - - // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination - // Unreachable messages with code: - // - // 2 (Protocol Unreachable), when the designated transport protocol - // is not supported; or - // - // 3 (Port Unreachable), when the designated transport protocol - // (e.g., UDP) is unable to demultiplex the datagram but has no - // protocol mechanism to inform the sender. - switch len(id.LocalAddress) { - case header.IPv4AddressSize: - if !r.Stack().AllowICMPMessage() { - r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment() - return true - } - // As per RFC 1812 Section 4.3.2.3 - // - // ICMP datagram SHOULD contain as much of the original - // datagram as possible without the length of the ICMP - // datagram exceeding 576 bytes - // - // NOTE: The above RFC referenced is different from the original - // recommendation in RFC 1122 where it mentioned that at least 8 - // bytes of the payload must be included. Today linux and other - // systems implement the] RFC1812 definition and not the original - // RFC 1122 requirement. - mtu := int(r.MTU()) - if mtu > header.IPv4MinimumProcessableDatagramSize { - mtu = header.IPv4MinimumProcessableDatagramSize - } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize - available := int(mtu) - headerLen - payloadLen := pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size() + pkt.Data.Size() - if payloadLen > available { - payloadLen = available - } - - // The buffers used by pkt may be used elsewhere in the system. - // For example, a raw or packet socket may use what UDP - // considers an unreachable destination. Thus we deep copy pkt - // to prevent multiple ownership and SR errors. - newHeader := append(buffer.View(nil), pkt.NetworkHeader().View()...) - newHeader = append(newHeader, pkt.TransportHeader().View()...) - payload := newHeader.ToVectorisedView() - payload.AppendView(pkt.Data.ToView()) - payload.CapLength(payloadLen) - - icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: headerLen, - Data: payload, - }) - icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4PortUnreachable) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) - - case header.IPv6AddressSize: - if !r.Stack().AllowICMPMessage() { - r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment() - return true - } - - // As per RFC 4443 section 2.4 - // - // (c) Every ICMPv6 error message (type < 128) MUST include - // as much of the IPv6 offending (invoking) packet (the - // packet that caused the error) as possible without making - // the error message packet exceed the minimum IPv6 MTU - // [IPv6]. - mtu := int(r.MTU()) - if mtu > header.IPv6MinimumMTU { - mtu = header.IPv6MinimumMTU - } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize - available := int(mtu) - headerLen - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - payloadLen := len(network) + len(transport) + pkt.Data.Size() - if payloadLen > available { - payloadLen = available - } - payload := buffer.NewVectorisedView(len(network)+len(transport), []buffer.View{network, transport}) - payload.Append(pkt.Data) - payload.CapLength(payloadLen) - - icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: headerLen, - Data: payload, - }) - icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6PortUnreachable) - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, icmpPkt.Data)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) - } - return true + return stack.UnknownDestinationPacketUnhandled } // SetOption implements stack.TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index d5881d183..64a5fc696 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -388,6 +388,10 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } + if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want { + c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want) + } + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) b := vv.ToView() |