summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/BUILD13
-rw-r--r--pkg/tcpip/transport/icmp/BUILD61
-rw-r--r--pkg/tcpip/transport/icmp/icmp_packet_list.go221
-rw-r--r--pkg/tcpip/transport/icmp/icmp_state_autogen.go167
-rw-r--r--pkg/tcpip/transport/icmp/icmp_test.go239
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD46
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_test.go318
-rw-r--r--pkg/tcpip/transport/internal/network/network_state_autogen.go110
-rw-r--r--pkg/tcpip/transport/packet/BUILD37
-rw-r--r--pkg/tcpip/transport/packet/packet_list.go221
-rw-r--r--pkg/tcpip/transport/packet/packet_state_autogen.go170
-rw-r--r--pkg/tcpip/transport/raw/BUILD41
-rw-r--r--pkg/tcpip/transport/raw/raw_packet_list.go221
-rw-r--r--pkg/tcpip/transport/raw/raw_state_autogen.go161
-rw-r--r--pkg/tcpip/transport/tcp/BUILD132
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go650
-rw-r--r--pkg/tcpip/transport/tcp/rcv_test.go74
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard_test.go249
-rw-r--r--pkg/tcpip/transport/tcp/segment_test.go69
-rw-r--r--pkg/tcpip/transport/tcp/tcp_endpoint_list.go221
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go559
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go1101
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go704
-rw-r--r--pkg/tcpip/transport/tcp/tcp_segment_list.go221
-rw-r--r--pkg/tcpip/transport/tcp/tcp_state_autogen.go901
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go8602
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go311
-rw-r--r--pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go3
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/BUILD26
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go1268
-rw-r--r--pkg/tcpip/transport/tcp/timer_test.go50
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD23
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go511
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go3
-rw-r--r--pkg/tcpip/transport/transport_state_autogen.go3
-rw-r--r--pkg/tcpip/transport/udp/BUILD68
-rw-r--r--pkg/tcpip/transport/udp/udp_packet_list.go221
-rw-r--r--pkg/tcpip/transport/udp/udp_state_autogen.go197
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go2602
39 files changed, 3041 insertions, 17754 deletions
diff --git a/pkg/tcpip/transport/BUILD b/pkg/tcpip/transport/BUILD
deleted file mode 100644
index af332ed91..000000000
--- a/pkg/tcpip/transport/BUILD
+++ /dev/null
@@ -1,13 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "transport",
- srcs = [
- "datagram.go",
- "transport.go",
- ],
- visibility = ["//visibility:public"],
- deps = ["//pkg/tcpip"],
-)
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
deleted file mode 100644
index 4718ec4ec..000000000
--- a/pkg/tcpip/transport/icmp/BUILD
+++ /dev/null
@@ -1,61 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "icmp_packet_list",
- out = "icmp_packet_list.go",
- package = "icmp",
- prefix = "icmpPacket",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*icmpPacket",
- "Linker": "*icmpPacket",
- },
-)
-
-go_library(
- name = "icmp",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "icmp_packet_list.go",
- "protocol.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport",
- "//pkg/tcpip/transport/internal/network",
- "//pkg/tcpip/transport/raw",
- "//pkg/tcpip/transport/tcp",
- "//pkg/waiter",
- ],
-)
-
-go_test(
- name = "icmp_x_test",
- size = "small",
- srcs = ["icmp_test.go"],
- deps = [
- ":icmp",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go
new file mode 100644
index 000000000..0aacdad3f
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go
@@ -0,0 +1,221 @@
+package icmp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type icmpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type icmpPacketList struct {
+ head *icmpPacket
+ tail *icmpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *icmpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *icmpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *icmpPacketList) Front() *icmpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *icmpPacketList) Back() *icmpPacket {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *icmpPacketList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (icmpPacketElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *icmpPacketList) PushFront(e *icmpPacket) {
+ linker := icmpPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *icmpPacketList) PushBack(e *icmpPacket) {
+ linker := icmpPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *icmpPacketList) PushBackList(m *icmpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) {
+ bLinker := icmpPacketElementMapper{}.linkerFor(b)
+ eLinker := icmpPacketElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ icmpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) {
+ aLinker := icmpPacketElementMapper{}.linkerFor(a)
+ eLinker := icmpPacketElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ icmpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *icmpPacketList) Remove(e *icmpPacket) {
+ linker := icmpPacketElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ icmpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type icmpPacketEntry struct {
+ next *icmpPacket
+ prev *icmpPacket
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *icmpPacketEntry) Next() *icmpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *icmpPacketEntry) Prev() *icmpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *icmpPacketEntry) SetNext(elem *icmpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
new file mode 100644
index 000000000..5ba2d9925
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
@@ -0,0 +1,167 @@
+// automatically generated by stateify.
+
+package icmp
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (p *icmpPacket) StateTypeName() string {
+ return "pkg/tcpip/transport/icmp.icmpPacket"
+}
+
+func (p *icmpPacket) StateFields() []string {
+ return []string{
+ "icmpPacketEntry",
+ "senderAddress",
+ "data",
+ "receivedAt",
+ }
+}
+
+func (p *icmpPacket) beforeSave() {}
+
+// +checklocksignore
+func (p *icmpPacket) StateSave(stateSinkObject state.Sink) {
+ p.beforeSave()
+ var dataValue buffer.VectorisedView
+ dataValue = p.saveData()
+ stateSinkObject.SaveValue(2, dataValue)
+ var receivedAtValue int64
+ receivedAtValue = p.saveReceivedAt()
+ stateSinkObject.SaveValue(3, receivedAtValue)
+ stateSinkObject.Save(0, &p.icmpPacketEntry)
+ stateSinkObject.Save(1, &p.senderAddress)
+}
+
+func (p *icmpPacket) afterLoad() {}
+
+// +checklocksignore
+func (p *icmpPacket) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &p.icmpPacketEntry)
+ stateSourceObject.Load(1, &p.senderAddress)
+ stateSourceObject.LoadValue(2, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) })
+ stateSourceObject.LoadValue(3, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) })
+}
+
+func (e *endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/icmp.endpoint"
+}
+
+func (e *endpoint) StateFields() []string {
+ return []string{
+ "DefaultSocketOptionsHandler",
+ "transProto",
+ "waiterQueue",
+ "uniqueID",
+ "net",
+ "stats",
+ "ops",
+ "rcvReady",
+ "rcvList",
+ "rcvBufSize",
+ "rcvClosed",
+ "frozen",
+ "ident",
+ }
+}
+
+// +checklocksignore
+func (e *endpoint) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.DefaultSocketOptionsHandler)
+ stateSinkObject.Save(1, &e.transProto)
+ stateSinkObject.Save(2, &e.waiterQueue)
+ stateSinkObject.Save(3, &e.uniqueID)
+ stateSinkObject.Save(4, &e.net)
+ stateSinkObject.Save(5, &e.stats)
+ stateSinkObject.Save(6, &e.ops)
+ stateSinkObject.Save(7, &e.rcvReady)
+ stateSinkObject.Save(8, &e.rcvList)
+ stateSinkObject.Save(9, &e.rcvBufSize)
+ stateSinkObject.Save(10, &e.rcvClosed)
+ stateSinkObject.Save(11, &e.frozen)
+ stateSinkObject.Save(12, &e.ident)
+}
+
+// +checklocksignore
+func (e *endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.DefaultSocketOptionsHandler)
+ stateSourceObject.Load(1, &e.transProto)
+ stateSourceObject.Load(2, &e.waiterQueue)
+ stateSourceObject.Load(3, &e.uniqueID)
+ stateSourceObject.Load(4, &e.net)
+ stateSourceObject.Load(5, &e.stats)
+ stateSourceObject.Load(6, &e.ops)
+ stateSourceObject.Load(7, &e.rcvReady)
+ stateSourceObject.Load(8, &e.rcvList)
+ stateSourceObject.Load(9, &e.rcvBufSize)
+ stateSourceObject.Load(10, &e.rcvClosed)
+ stateSourceObject.Load(11, &e.frozen)
+ stateSourceObject.Load(12, &e.ident)
+ stateSourceObject.AfterLoad(e.afterLoad)
+}
+
+func (l *icmpPacketList) StateTypeName() string {
+ return "pkg/tcpip/transport/icmp.icmpPacketList"
+}
+
+func (l *icmpPacketList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *icmpPacketList) beforeSave() {}
+
+// +checklocksignore
+func (l *icmpPacketList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *icmpPacketList) afterLoad() {}
+
+// +checklocksignore
+func (l *icmpPacketList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *icmpPacketEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/icmp.icmpPacketEntry"
+}
+
+func (e *icmpPacketEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *icmpPacketEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *icmpPacketEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *icmpPacketEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *icmpPacketEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*icmpPacket)(nil))
+ state.Register((*endpoint)(nil))
+ state.Register((*icmpPacketList)(nil))
+ state.Register((*icmpPacketEntry)(nil))
+}
diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go
deleted file mode 100644
index 729f50e9a..000000000
--- a/pkg/tcpip/transport/icmp/icmp_test.go
+++ /dev/null
@@ -1,239 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package icmp_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// TODO(https://gvisor.dev/issues/5623): Finish unit testing the icmp package.
-// See the issue for remaining areas of work.
-
-var (
- localV4Addr1 = testutil.MustParse4("10.0.0.1")
- localV4Addr2 = testutil.MustParse4("10.0.0.2")
- remoteV4Addr = testutil.MustParse4("10.0.0.3")
-)
-
-func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name string, addrV4 tcpip.Address) *channel.Endpoint {
- t.Helper()
-
- ep := channel.New(1 /* size */, header.IPv4MinimumMTU, "" /* linkAddr */)
- t.Cleanup(ep.Close)
-
- wep := stack.LinkEndpoint(ep)
- if testing.Verbose() {
- wep = sniffer.New(ep)
- }
-
- opts := stack.NICOptions{Name: name}
- if err := s.CreateNICWithOptions(id, wep, opts); err != nil {
- t.Fatalf("s.CreateNIC(%d, _) = %s", id, err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: addrV4.WithPrefix(),
- }
- if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
- }
-
- s.AddRoute(tcpip.Route{
- Destination: header.IPv4EmptySubnet,
- NIC: id,
- })
-
- return ep
-}
-
-func writePayload(buf []byte) {
- for i := range buf {
- buf[i] = byte(i)
- }
-}
-
-func newICMPv4EchoRequest(payloadSize uint32) buffer.View {
- buf := buffer.NewView(header.ICMPv4MinimumSize + int(payloadSize))
- writePayload(buf[header.ICMPv4MinimumSize:])
-
- icmp := header.ICMPv4(buf)
- icmp.SetType(header.ICMPv4Echo)
- // No need to set the checksum; it is reset by the socket before the packet
- // is sent.
-
- return buf
-}
-
-// TestWriteUnboundWithBindToDevice exercises writing to an unbound ICMP socket
-// when SO_BINDTODEVICE is set to the non-default NIC for that subnet.
-//
-// Only IPv4 is tested. The logic to determine which NIC to use is agnostic to
-// the version of IP.
-func TestWriteUnboundWithBindToDevice(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
- HandleLocal: true,
- })
-
- // Add two NICs, both with default routes on the same subnet. The first NIC
- // added will be the default NIC for that subnet.
- defaultEP := addNICWithDefaultRoute(t, s, 1, "default", localV4Addr1)
- alternateEP := addNICWithDefaultRoute(t, s, 2, "alternate", localV4Addr2)
-
- socket, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber4, ipv4.ProtocolNumber, err)
- }
- defer socket.Close()
-
- echoPayloadSize := defaultEP.MTU() - header.IPv4MinimumSize - header.ICMPv4MinimumSize
-
- // Send a packet without SO_BINDTODEVICE. This verifies that the first NIC
- // to be added is the default NIC to send packets when not explicitly bound.
- {
- buf := newICMPv4EchoRequest(echoPayloadSize)
- r := buf.Reader()
- n, err := socket.Write(&r, tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: remoteV4Addr},
- })
- if err != nil {
- t.Fatalf("socket.Write(_, {To:%s}) = %s", remoteV4Addr, err)
- }
- if n != int64(len(buf)) {
- t.Fatalf("got n = %d, want n = %d", n, len(buf))
- }
-
- // Verify the packet was sent out the default NIC.
- p, ok := defaultEP.Read()
- if !ok {
- t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- checker.IPv4(t, b, []checker.NetworkChecker{
- checker.SrcAddr(localV4Addr1),
- checker.DstAddr(remoteV4Addr),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Echo),
- checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
- ),
- }...)
-
- // Verify the packet was not sent out the alternate NIC.
- if p, ok := alternateEP.Read(); ok {
- t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p)
- }
- }
-
- // Send a packet with SO_BINDTODEVICE. This exercises reliance on
- // SO_BINDTODEVICE to route the packet to the alternate NIC.
- {
- // Use SO_BINDTODEVICE to send over the alternate NIC by default.
- socket.SocketOptions().SetBindToDevice(2)
-
- buf := newICMPv4EchoRequest(echoPayloadSize)
- r := buf.Reader()
- n, err := socket.Write(&r, tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: remoteV4Addr},
- })
- if err != nil {
- t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err)
- }
- if n != int64(len(buf)) {
- t.Fatalf("got n = %d, want n = %d", n, len(buf))
- }
-
- // Verify the packet was not sent out the default NIC.
- if p, ok := defaultEP.Read(); ok {
- t.Fatalf("got defaultEP.Read(_) = %+v, true; want = _, false", p)
- }
-
- // Verify the packet was sent out the alternate NIC.
- p, ok := alternateEP.Read()
- if !ok {
- t.Fatalf("got alternateEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- checker.IPv4(t, b, []checker.NetworkChecker{
- checker.SrcAddr(localV4Addr2),
- checker.DstAddr(remoteV4Addr),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Echo),
- checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
- ),
- }...)
- }
-
- // Send a packet with SO_BINDTODEVICE cleared. This verifies that clearing
- // the device binding will fallback to using the default NIC to send
- // packets.
- {
- socket.SocketOptions().SetBindToDevice(0)
-
- buf := newICMPv4EchoRequest(echoPayloadSize)
- r := buf.Reader()
- n, err := socket.Write(&r, tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: remoteV4Addr},
- })
- if err != nil {
- t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err)
- }
- if n != int64(len(buf)) {
- t.Fatalf("got n = %d, want n = %d", n, len(buf))
- }
-
- // Verify the packet was sent out the default NIC.
- p, ok := defaultEP.Read()
- if !ok {
- t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- checker.IPv4(t, b, []checker.NetworkChecker{
- checker.SrcAddr(localV4Addr1),
- checker.DstAddr(remoteV4Addr),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Echo),
- checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
- ),
- }...)
-
- // Verify the packet was not sent out the alternate NIC.
- if p, ok := alternateEP.Read(); ok {
- t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p)
- }
- }
-}
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
deleted file mode 100644
index 3818cb04e..000000000
--- a/pkg/tcpip/transport/internal/network/BUILD
+++ /dev/null
@@ -1,46 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "network",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- ],
- visibility = [
- "//pkg/tcpip/transport/icmp:__pkg__",
- "//pkg/tcpip/transport/raw:__pkg__",
- "//pkg/tcpip/transport/udp:__pkg__",
- ],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport",
- ],
-)
-
-go_test(
- name = "network_test",
- size = "small",
- srcs = ["endpoint_test.go"],
- deps = [
- ":network",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport",
- "//pkg/tcpip/transport/udp",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go
deleted file mode 100644
index f263a9ea2..000000000
--- a/pkg/tcpip/transport/internal/network/endpoint_test.go
+++ /dev/null
@@ -1,318 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package network_test
-
-import (
- "fmt"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport"
- "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
-)
-
-var (
- ipv4NICAddr = testutil.MustParse4("1.2.3.4")
- ipv6NICAddr = testutil.MustParse6("a::1")
- ipv4RemoteAddr = testutil.MustParse4("6.7.8.9")
- ipv6RemoteAddr = testutil.MustParse6("b::1")
-)
-
-func TestEndpointStateTransitions(t *testing.T) {
- const nicID = 1
-
- data := buffer.View([]byte{1, 2, 4, 5})
- v4Checker := func(t *testing.T, b buffer.View) {
- checker.IPv4(t, b,
- checker.SrcAddr(ipv4NICAddr),
- checker.DstAddr(ipv4RemoteAddr),
- checker.IPPayload(data),
- )
- }
-
- v6Checker := func(t *testing.T, b buffer.View) {
- checker.IPv6(t, b,
- checker.SrcAddr(ipv6NICAddr),
- checker.DstAddr(ipv6RemoteAddr),
- checker.IPPayload(data),
- )
- }
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- expectedMaxHeaderLength uint16
- expectedNetProto tcpip.NetworkProtocolNumber
- expectedLocalAddr tcpip.Address
- bindAddr tcpip.Address
- expectedBoundAddr tcpip.Address
- remoteAddr tcpip.Address
- expectedRemoteAddr tcpip.Address
- checker func(*testing.T, buffer.View)
- }{
- {
- name: "IPv4",
- netProto: ipv4.ProtocolNumber,
- expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
- expectedNetProto: ipv4.ProtocolNumber,
- expectedLocalAddr: ipv4NICAddr,
- bindAddr: header.IPv4AllSystems,
- expectedBoundAddr: header.IPv4AllSystems,
- remoteAddr: ipv4RemoteAddr,
- expectedRemoteAddr: ipv4RemoteAddr,
- checker: v4Checker,
- },
- {
- name: "IPv6",
- netProto: ipv6.ProtocolNumber,
- expectedMaxHeaderLength: header.IPv6FixedHeaderSize,
- expectedNetProto: ipv6.ProtocolNumber,
- expectedLocalAddr: ipv6NICAddr,
- bindAddr: header.IPv6AllNodesMulticastAddress,
- expectedBoundAddr: header.IPv6AllNodesMulticastAddress,
- remoteAddr: ipv6RemoteAddr,
- expectedRemoteAddr: ipv6RemoteAddr,
- checker: v6Checker,
- },
- {
- name: "IPv4-mapped-IPv6",
- netProto: ipv6.ProtocolNumber,
- expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
- expectedNetProto: ipv4.ProtocolNumber,
- expectedLocalAddr: ipv4NICAddr,
- bindAddr: testutil.MustParse6("::ffff:e000:0001"),
- expectedBoundAddr: header.IPv4AllSystems,
- remoteAddr: testutil.MustParse6("::ffff:0607:0809"),
- expectedRemoteAddr: ipv4RemoteAddr,
- checker: v4Checker,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- Clock: &faketime.NullClock{},
- })
- e := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- ipv4ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: ipv4NICAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}: %s", nicID, ipv4ProtocolAddr, err)
- }
- ipv6ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: ipv6NICAddr.WithPrefix(),
- }
-
- if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {Destination: ipv4RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
- {Destination: ipv6RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
- })
-
- var ops tcpip.SocketOptions
- var ep network.Endpoint
- ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
- defer ep.Close()
- if state := ep.State(); state != transport.DatagramEndpointStateInitial {
- t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial)
- }
-
- bindAddr := tcpip.FullAddress{Addr: test.bindAddr}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
- }
- if state := ep.State(); state != transport.DatagramEndpointStateBound {
- t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateBound)
- }
- if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedBoundAddr}); diff != "" {
- t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
- }
- if addr, connected := ep.GetRemoteAddress(); connected {
- t.Errorf("got ep.GetRemoteAddress() = (true, %#v), want = (false, _)", addr)
- }
-
- connectAddr := tcpip.FullAddress{Addr: test.remoteAddr}
- if err := ep.Connect(connectAddr); err != nil {
- t.Fatalf("ep.Connect(%#v): %s", connectAddr, err)
- }
- if state := ep.State(); state != transport.DatagramEndpointStateConnected {
- t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateConnected)
- }
- if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedLocalAddr}); diff != "" {
- t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
- }
- if addr, connected := ep.GetRemoteAddress(); !connected {
- t.Errorf("got ep.GetRemoteAddress() = (false, _), want = (true, %#v)", connectAddr)
- } else if diff := cmp.Diff(addr, tcpip.FullAddress{Addr: test.expectedRemoteAddr}); diff != "" {
- t.Errorf("remote address mismatch (-want +got):\n%s", diff)
- }
-
- ctx, err := ep.AcquireContextForWrite(tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("ep.AcquireContexForWrite({}): %s", err)
- }
- defer ctx.Release()
- info := ctx.PacketInfo()
- if diff := cmp.Diff(network.WritePacketInfo{
- NetProto: test.expectedNetProto,
- LocalAddress: test.expectedLocalAddr,
- RemoteAddress: test.expectedRemoteAddr,
- MaxHeaderLength: test.expectedMaxHeaderLength,
- RequiresTXTransportChecksum: true,
- }, info); diff != "" {
- t.Errorf("write packet info mismatch (-want +got):\n%s", diff)
- }
- if err := ctx.WritePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(info.MaxHeaderLength),
- Data: data.ToVectorisedView(),
- }), false /* headerIncluded */); err != nil {
- t.Fatalf("ctx.WritePacket(_, false): %s", err)
- }
- if pkt, ok := e.Read(); !ok {
- t.Fatalf("expected packet to be read from link endpoint")
- } else {
- test.checker(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()))
- }
-
- ep.Close()
- if state := ep.State(); state != transport.DatagramEndpointStateClosed {
- t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateClosed)
- }
- })
- }
-}
-
-func TestBindNICID(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- bindAddr tcpip.Address
- unicast bool
- }{
- {
- name: "IPv4 multicast",
- netProto: ipv4.ProtocolNumber,
- bindAddr: header.IPv4AllSystems,
- unicast: false,
- },
- {
- name: "IPv6 multicast",
- netProto: ipv6.ProtocolNumber,
- bindAddr: header.IPv6AllNodesMulticastAddress,
- unicast: false,
- },
- {
- name: "IPv4 unicast",
- netProto: ipv4.ProtocolNumber,
- bindAddr: ipv4NICAddr,
- unicast: true,
- },
- {
- name: "IPv6 unicast",
- netProto: ipv6.ProtocolNumber,
- bindAddr: ipv6NICAddr,
- unicast: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, testBindNICID := range []tcpip.NICID{0, nicID} {
- t.Run(fmt.Sprintf("BindNICID=%d", testBindNICID), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- Clock: &faketime.NullClock{},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- ipv4ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: ipv4NICAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err)
- }
- ipv6ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: ipv6NICAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
- }
-
- var ops tcpip.SocketOptions
- var ep network.Endpoint
- ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
- defer ep.Close()
- if ep.WasBound() {
- t.Fatal("got ep.WasBound() = true, want = false")
- }
- wantInfo := stack.TransportEndpointInfo{NetProto: test.netProto, TransProto: udp.ProtocolNumber}
- if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
- t.Fatalf("ep.Info() mismatch (-want +got):\n%s", diff)
- }
-
- bindAddr := tcpip.FullAddress{Addr: test.bindAddr, NIC: testBindNICID}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
- }
- if !ep.WasBound() {
- t.Error("got ep.WasBound() = false, want = true")
- }
- wantInfo.ID = stack.TransportEndpointID{LocalAddress: bindAddr.Addr}
- wantInfo.BindAddr = bindAddr.Addr
- wantInfo.BindNICID = bindAddr.NIC
- if test.unicast {
- wantInfo.RegisterNICID = nicID
- } else {
- wantInfo.RegisterNICID = bindAddr.NIC
- }
- if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
- t.Errorf("ep.Info() mismatch (-want +got):\n%s", diff)
- }
- })
- }
- })
- }
-}
diff --git a/pkg/tcpip/transport/internal/network/network_state_autogen.go b/pkg/tcpip/transport/internal/network/network_state_autogen.go
new file mode 100644
index 000000000..1515c8632
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/network_state_autogen.go
@@ -0,0 +1,110 @@
+// automatically generated by stateify.
+
+package network
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (e *Endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/internal/network.Endpoint"
+}
+
+func (e *Endpoint) StateFields() []string {
+ return []string{
+ "ops",
+ "netProto",
+ "transProto",
+ "wasBound",
+ "owner",
+ "writeShutdown",
+ "effectiveNetProto",
+ "multicastMemberships",
+ "ttl",
+ "multicastTTL",
+ "multicastAddr",
+ "multicastNICID",
+ "ipv4TOS",
+ "ipv6TClass",
+ "info",
+ "state",
+ }
+}
+
+func (e *Endpoint) beforeSave() {}
+
+// +checklocksignore
+func (e *Endpoint) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.ops)
+ stateSinkObject.Save(1, &e.netProto)
+ stateSinkObject.Save(2, &e.transProto)
+ stateSinkObject.Save(3, &e.wasBound)
+ stateSinkObject.Save(4, &e.owner)
+ stateSinkObject.Save(5, &e.writeShutdown)
+ stateSinkObject.Save(6, &e.effectiveNetProto)
+ stateSinkObject.Save(7, &e.multicastMemberships)
+ stateSinkObject.Save(8, &e.ttl)
+ stateSinkObject.Save(9, &e.multicastTTL)
+ stateSinkObject.Save(10, &e.multicastAddr)
+ stateSinkObject.Save(11, &e.multicastNICID)
+ stateSinkObject.Save(12, &e.ipv4TOS)
+ stateSinkObject.Save(13, &e.ipv6TClass)
+ stateSinkObject.Save(14, &e.info)
+ stateSinkObject.Save(15, &e.state)
+}
+
+func (e *Endpoint) afterLoad() {}
+
+// +checklocksignore
+func (e *Endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.ops)
+ stateSourceObject.Load(1, &e.netProto)
+ stateSourceObject.Load(2, &e.transProto)
+ stateSourceObject.Load(3, &e.wasBound)
+ stateSourceObject.Load(4, &e.owner)
+ stateSourceObject.Load(5, &e.writeShutdown)
+ stateSourceObject.Load(6, &e.effectiveNetProto)
+ stateSourceObject.Load(7, &e.multicastMemberships)
+ stateSourceObject.Load(8, &e.ttl)
+ stateSourceObject.Load(9, &e.multicastTTL)
+ stateSourceObject.Load(10, &e.multicastAddr)
+ stateSourceObject.Load(11, &e.multicastNICID)
+ stateSourceObject.Load(12, &e.ipv4TOS)
+ stateSourceObject.Load(13, &e.ipv6TClass)
+ stateSourceObject.Load(14, &e.info)
+ stateSourceObject.Load(15, &e.state)
+}
+
+func (m *multicastMembership) StateTypeName() string {
+ return "pkg/tcpip/transport/internal/network.multicastMembership"
+}
+
+func (m *multicastMembership) StateFields() []string {
+ return []string{
+ "nicID",
+ "multicastAddr",
+ }
+}
+
+func (m *multicastMembership) beforeSave() {}
+
+// +checklocksignore
+func (m *multicastMembership) StateSave(stateSinkObject state.Sink) {
+ m.beforeSave()
+ stateSinkObject.Save(0, &m.nicID)
+ stateSinkObject.Save(1, &m.multicastAddr)
+}
+
+func (m *multicastMembership) afterLoad() {}
+
+// +checklocksignore
+func (m *multicastMembership) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &m.nicID)
+ stateSourceObject.Load(1, &m.multicastAddr)
+}
+
+func init() {
+ state.Register((*Endpoint)(nil))
+ state.Register((*multicastMembership)(nil))
+}
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
deleted file mode 100644
index b989b1209..000000000
--- a/pkg/tcpip/transport/packet/BUILD
+++ /dev/null
@@ -1,37 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "packet_list",
- out = "packet_list.go",
- package = "packet",
- prefix = "packet",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*packet",
- "Linker": "*packet",
- },
-)
-
-go_library(
- name = "packet",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "packet_list.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/packet/packet_list.go b/pkg/tcpip/transport/packet/packet_list.go
new file mode 100644
index 000000000..2c983aad0
--- /dev/null
+++ b/pkg/tcpip/transport/packet/packet_list.go
@@ -0,0 +1,221 @@
+package packet
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type packetElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (packetElementMapper) linkerFor(elem *packet) *packet { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type packetList struct {
+ head *packet
+ tail *packet
+}
+
+// Reset resets list l to the empty state.
+func (l *packetList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *packetList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *packetList) Front() *packet {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *packetList) Back() *packet {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *packetList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (packetElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *packetList) PushFront(e *packet) {
+ linker := packetElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ packetElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *packetList) PushBack(e *packet) {
+ linker := packetElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ packetElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *packetList) PushBackList(m *packetList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ packetElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *packetList) InsertAfter(b, e *packet) {
+ bLinker := packetElementMapper{}.linkerFor(b)
+ eLinker := packetElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ packetElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *packetList) InsertBefore(a, e *packet) {
+ aLinker := packetElementMapper{}.linkerFor(a)
+ eLinker := packetElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ packetElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *packetList) Remove(e *packet) {
+ linker := packetElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ packetElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ packetElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type packetEntry struct {
+ next *packet
+ prev *packet
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *packetEntry) Next() *packet {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *packetEntry) Prev() *packet {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *packetEntry) SetNext(elem *packet) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *packetEntry) SetPrev(elem *packet) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go
new file mode 100644
index 000000000..b1e685bb4
--- /dev/null
+++ b/pkg/tcpip/transport/packet/packet_state_autogen.go
@@ -0,0 +1,170 @@
+// automatically generated by stateify.
+
+package packet
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (p *packet) StateTypeName() string {
+ return "pkg/tcpip/transport/packet.packet"
+}
+
+func (p *packet) StateFields() []string {
+ return []string{
+ "packetEntry",
+ "data",
+ "receivedAt",
+ "senderAddr",
+ "packetInfo",
+ }
+}
+
+func (p *packet) beforeSave() {}
+
+// +checklocksignore
+func (p *packet) StateSave(stateSinkObject state.Sink) {
+ p.beforeSave()
+ var dataValue buffer.VectorisedView
+ dataValue = p.saveData()
+ stateSinkObject.SaveValue(1, dataValue)
+ var receivedAtValue int64
+ receivedAtValue = p.saveReceivedAt()
+ stateSinkObject.SaveValue(2, receivedAtValue)
+ stateSinkObject.Save(0, &p.packetEntry)
+ stateSinkObject.Save(3, &p.senderAddr)
+ stateSinkObject.Save(4, &p.packetInfo)
+}
+
+func (p *packet) afterLoad() {}
+
+// +checklocksignore
+func (p *packet) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &p.packetEntry)
+ stateSourceObject.Load(3, &p.senderAddr)
+ stateSourceObject.Load(4, &p.packetInfo)
+ stateSourceObject.LoadValue(1, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) })
+ stateSourceObject.LoadValue(2, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) })
+}
+
+func (ep *endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/packet.endpoint"
+}
+
+func (ep *endpoint) StateFields() []string {
+ return []string{
+ "DefaultSocketOptionsHandler",
+ "waiterQueue",
+ "cooked",
+ "ops",
+ "stats",
+ "rcvList",
+ "rcvBufSize",
+ "rcvClosed",
+ "rcvDisabled",
+ "closed",
+ "boundNetProto",
+ "boundNIC",
+ "lastError",
+ }
+}
+
+// +checklocksignore
+func (ep *endpoint) StateSave(stateSinkObject state.Sink) {
+ ep.beforeSave()
+ stateSinkObject.Save(0, &ep.DefaultSocketOptionsHandler)
+ stateSinkObject.Save(1, &ep.waiterQueue)
+ stateSinkObject.Save(2, &ep.cooked)
+ stateSinkObject.Save(3, &ep.ops)
+ stateSinkObject.Save(4, &ep.stats)
+ stateSinkObject.Save(5, &ep.rcvList)
+ stateSinkObject.Save(6, &ep.rcvBufSize)
+ stateSinkObject.Save(7, &ep.rcvClosed)
+ stateSinkObject.Save(8, &ep.rcvDisabled)
+ stateSinkObject.Save(9, &ep.closed)
+ stateSinkObject.Save(10, &ep.boundNetProto)
+ stateSinkObject.Save(11, &ep.boundNIC)
+ stateSinkObject.Save(12, &ep.lastError)
+}
+
+// +checklocksignore
+func (ep *endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &ep.DefaultSocketOptionsHandler)
+ stateSourceObject.Load(1, &ep.waiterQueue)
+ stateSourceObject.Load(2, &ep.cooked)
+ stateSourceObject.Load(3, &ep.ops)
+ stateSourceObject.Load(4, &ep.stats)
+ stateSourceObject.Load(5, &ep.rcvList)
+ stateSourceObject.Load(6, &ep.rcvBufSize)
+ stateSourceObject.Load(7, &ep.rcvClosed)
+ stateSourceObject.Load(8, &ep.rcvDisabled)
+ stateSourceObject.Load(9, &ep.closed)
+ stateSourceObject.Load(10, &ep.boundNetProto)
+ stateSourceObject.Load(11, &ep.boundNIC)
+ stateSourceObject.Load(12, &ep.lastError)
+ stateSourceObject.AfterLoad(ep.afterLoad)
+}
+
+func (l *packetList) StateTypeName() string {
+ return "pkg/tcpip/transport/packet.packetList"
+}
+
+func (l *packetList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *packetList) beforeSave() {}
+
+// +checklocksignore
+func (l *packetList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *packetList) afterLoad() {}
+
+// +checklocksignore
+func (l *packetList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *packetEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/packet.packetEntry"
+}
+
+func (e *packetEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *packetEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *packetEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *packetEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *packetEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*packet)(nil))
+ state.Register((*endpoint)(nil))
+ state.Register((*packetList)(nil))
+ state.Register((*packetEntry)(nil))
+}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
deleted file mode 100644
index b7e97e218..000000000
--- a/pkg/tcpip/transport/raw/BUILD
+++ /dev/null
@@ -1,41 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "raw_packet_list",
- out = "raw_packet_list.go",
- package = "raw",
- prefix = "rawPacket",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*rawPacket",
- "Linker": "*rawPacket",
- },
-)
-
-go_library(
- name = "raw",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "protocol.go",
- "raw_packet_list.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport",
- "//pkg/tcpip/transport/internal/network",
- "//pkg/tcpip/transport/packet",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/raw/raw_packet_list.go b/pkg/tcpip/transport/raw/raw_packet_list.go
new file mode 100644
index 000000000..48804ff1b
--- /dev/null
+++ b/pkg/tcpip/transport/raw/raw_packet_list.go
@@ -0,0 +1,221 @@
+package raw
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type rawPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (rawPacketElementMapper) linkerFor(elem *rawPacket) *rawPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type rawPacketList struct {
+ head *rawPacket
+ tail *rawPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *rawPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *rawPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *rawPacketList) Front() *rawPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *rawPacketList) Back() *rawPacket {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *rawPacketList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (rawPacketElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *rawPacketList) PushFront(e *rawPacket) {
+ linker := rawPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ rawPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *rawPacketList) PushBack(e *rawPacket) {
+ linker := rawPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ rawPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *rawPacketList) PushBackList(m *rawPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ rawPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ rawPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *rawPacketList) InsertAfter(b, e *rawPacket) {
+ bLinker := rawPacketElementMapper{}.linkerFor(b)
+ eLinker := rawPacketElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ rawPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *rawPacketList) InsertBefore(a, e *rawPacket) {
+ aLinker := rawPacketElementMapper{}.linkerFor(a)
+ eLinker := rawPacketElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ rawPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *rawPacketList) Remove(e *rawPacket) {
+ linker := rawPacketElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ rawPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ rawPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type rawPacketEntry struct {
+ next *rawPacket
+ prev *rawPacket
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *rawPacketEntry) Next() *rawPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *rawPacketEntry) Prev() *rawPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *rawPacketEntry) SetNext(elem *rawPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *rawPacketEntry) SetPrev(elem *rawPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go
new file mode 100644
index 000000000..0775faa58
--- /dev/null
+++ b/pkg/tcpip/transport/raw/raw_state_autogen.go
@@ -0,0 +1,161 @@
+// automatically generated by stateify.
+
+package raw
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (p *rawPacket) StateTypeName() string {
+ return "pkg/tcpip/transport/raw.rawPacket"
+}
+
+func (p *rawPacket) StateFields() []string {
+ return []string{
+ "rawPacketEntry",
+ "data",
+ "receivedAt",
+ "senderAddr",
+ }
+}
+
+func (p *rawPacket) beforeSave() {}
+
+// +checklocksignore
+func (p *rawPacket) StateSave(stateSinkObject state.Sink) {
+ p.beforeSave()
+ var dataValue buffer.VectorisedView
+ dataValue = p.saveData()
+ stateSinkObject.SaveValue(1, dataValue)
+ var receivedAtValue int64
+ receivedAtValue = p.saveReceivedAt()
+ stateSinkObject.SaveValue(2, receivedAtValue)
+ stateSinkObject.Save(0, &p.rawPacketEntry)
+ stateSinkObject.Save(3, &p.senderAddr)
+}
+
+func (p *rawPacket) afterLoad() {}
+
+// +checklocksignore
+func (p *rawPacket) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &p.rawPacketEntry)
+ stateSourceObject.Load(3, &p.senderAddr)
+ stateSourceObject.LoadValue(1, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) })
+ stateSourceObject.LoadValue(2, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) })
+}
+
+func (e *endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/raw.endpoint"
+}
+
+func (e *endpoint) StateFields() []string {
+ return []string{
+ "DefaultSocketOptionsHandler",
+ "transProto",
+ "waiterQueue",
+ "associated",
+ "net",
+ "stats",
+ "ops",
+ "rcvList",
+ "rcvBufSize",
+ "rcvClosed",
+ "frozen",
+ }
+}
+
+// +checklocksignore
+func (e *endpoint) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.DefaultSocketOptionsHandler)
+ stateSinkObject.Save(1, &e.transProto)
+ stateSinkObject.Save(2, &e.waiterQueue)
+ stateSinkObject.Save(3, &e.associated)
+ stateSinkObject.Save(4, &e.net)
+ stateSinkObject.Save(5, &e.stats)
+ stateSinkObject.Save(6, &e.ops)
+ stateSinkObject.Save(7, &e.rcvList)
+ stateSinkObject.Save(8, &e.rcvBufSize)
+ stateSinkObject.Save(9, &e.rcvClosed)
+ stateSinkObject.Save(10, &e.frozen)
+}
+
+// +checklocksignore
+func (e *endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.DefaultSocketOptionsHandler)
+ stateSourceObject.Load(1, &e.transProto)
+ stateSourceObject.Load(2, &e.waiterQueue)
+ stateSourceObject.Load(3, &e.associated)
+ stateSourceObject.Load(4, &e.net)
+ stateSourceObject.Load(5, &e.stats)
+ stateSourceObject.Load(6, &e.ops)
+ stateSourceObject.Load(7, &e.rcvList)
+ stateSourceObject.Load(8, &e.rcvBufSize)
+ stateSourceObject.Load(9, &e.rcvClosed)
+ stateSourceObject.Load(10, &e.frozen)
+ stateSourceObject.AfterLoad(e.afterLoad)
+}
+
+func (l *rawPacketList) StateTypeName() string {
+ return "pkg/tcpip/transport/raw.rawPacketList"
+}
+
+func (l *rawPacketList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *rawPacketList) beforeSave() {}
+
+// +checklocksignore
+func (l *rawPacketList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *rawPacketList) afterLoad() {}
+
+// +checklocksignore
+func (l *rawPacketList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *rawPacketEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/raw.rawPacketEntry"
+}
+
+func (e *rawPacketEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *rawPacketEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *rawPacketEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *rawPacketEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *rawPacketEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*rawPacket)(nil))
+ state.Register((*endpoint)(nil))
+ state.Register((*rawPacketList)(nil))
+ state.Register((*rawPacketEntry)(nil))
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
deleted file mode 100644
index 20958d882..000000000
--- a/pkg/tcpip/transport/tcp/BUILD
+++ /dev/null
@@ -1,132 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test", "more_shards")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "tcp_segment_list",
- out = "tcp_segment_list.go",
- package = "tcp",
- prefix = "segment",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*segment",
- "Linker": "*segment",
- },
-)
-
-go_template_instance(
- name = "tcp_endpoint_list",
- out = "tcp_endpoint_list.go",
- package = "tcp",
- prefix = "endpoint",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*endpoint",
- "Linker": "*endpoint",
- },
-)
-
-go_library(
- name = "tcp",
- srcs = [
- "accept.go",
- "connect.go",
- "connect_unsafe.go",
- "cubic.go",
- "dispatcher.go",
- "endpoint.go",
- "endpoint_state.go",
- "forwarder.go",
- "protocol.go",
- "rack.go",
- "rcv.go",
- "reno.go",
- "reno_recovery.go",
- "sack.go",
- "sack_recovery.go",
- "sack_scoreboard.go",
- "segment.go",
- "segment_heap.go",
- "segment_queue.go",
- "segment_state.go",
- "segment_unsafe.go",
- "snd.go",
- "tcp_endpoint_list.go",
- "tcp_segment_list.go",
- "timer.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/rand",
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/hash/jenkins",
- "//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
- "//pkg/tcpip/internal/tcp",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/seqnum",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/raw",
- "//pkg/waiter",
- "@com_github_google_btree//:go_default_library",
- ],
-)
-
-go_test(
- name = "tcp_x_test",
- size = "large",
- srcs = [
- "dual_stack_test.go",
- "rcv_test.go",
- "sack_scoreboard_test.go",
- "tcp_noracedetector_test.go",
- "tcp_rack_test.go",
- "tcp_sack_test.go",
- "tcp_test.go",
- "tcp_timestamp_test.go",
- ],
- shard_count = more_shards,
- deps = [
- ":tcp",
- "//pkg/rand",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/seqnum",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/tcp/testing/context",
- "//pkg/test/testutil",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "tcp_test",
- size = "small",
- srcs = [
- "segment_test.go",
- "timer_test.go",
- ],
- library = ":tcp",
- deps = [
- "//pkg/sleep",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/stack",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
deleted file mode 100644
index 5342aacfd..000000000
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ /dev/null
@@ -1,650 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "strings"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-func TestV4MappedConnectOnV6Only(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- // Start connection attempt, it must fail.
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
- t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
-}
-
-func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
- // Start connection attempt.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&we)
-
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- synCheckers := append(checkers, checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ))
- checker.IPv4(t, b, synCheckers...)
-
- tcp := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
-
- iss := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: tcp.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive ACK packet.
- ackCheckers := append(checkers, checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+1),
- ))
- checker.IPv4(t, c.GetPacket(), ackCheckers...)
-
- // Wait for connection to be established.
- select {
- case <-ch:
- if err := c.EP.LastError(); err != nil {
- t.Fatalf("Unexpected error when connecting: %v", err)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-}
-
-func TestV4MappedConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Test the connection request.
- testV4Connect(t, c)
-}
-
-func TestV4ConnectWhenBoundToWildcard(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV4Connect(t, c)
-}
-
-func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to v4 mapped wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV4Connect(t, c)
-}
-
-func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to v4 mapped address.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV4Connect(t, c)
-}
-
-func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
- // Start connection attempt to IPv6 address.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&we)
-
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
- }
-
- // Receive SYN packet.
- b := c.GetV6Packet()
- synCheckers := append(checkers, checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ))
- checker.IPv6(t, b, synCheckers...)
-
- tcp := header.TCP(header.IPv6(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
-
- iss := seqnum.Value(789)
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: tcp.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive ACK packet.
- ackCheckers := append(checkers, checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+1),
- ))
- checker.IPv6(t, c.GetV6Packet(), ackCheckers...)
-
- // Wait for connection to be established.
- select {
- case <-ch:
- if err := c.EP.LastError(); err != nil {
- t.Fatalf("Unexpected error when connecting: %v", err)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-}
-
-func TestV6Connect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestV6ConnectV6Only(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestV6ConnectWhenBoundToWildcard(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) {
- c := context.NewWithOpts(t, context.Options{
- EnableV6: true,
- MTU: defaultMTU,
- })
- defer c.Cleanup()
-
- // Create a v6 endpoint but don't set the v6-only TCP option.
- c.CreateV6Endpoint(false)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to local address.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestV4RefuseOnV6Only(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Start listening.
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the RST reply.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.TCPAckNum(uint32(irs)+1),
- ),
- )
-}
-
-func TestV6RefuseOnBoundToV4Mapped(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind and listen.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the RST reply.
- checker.IPv6(t, c.GetV6Packet(),
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.TCPAckNum(uint32(irs)+1),
- ),
- )
-}
-
-func testV4Accept(t *testing.T, c *context.Context) {
- c.SetGSOEnabled(true)
- defer c.SetGSOEnabled(false)
-
- // Start listening.
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- iss := seqnum.Value(tcp.SequenceNumber())
- checker.IPv4(t, b,
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs)+1),
- ),
- )
-
- // Send ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- nep, _, err := c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- nep, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Check the peer address.
- addr, err := nep.GetRemoteAddress()
- if err != nil {
- t.Fatalf("GetRemoteAddress failed failed: %v", err)
- }
-
- if addr.Addr != context.TestAddr {
- t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr)
- }
-
- var r strings.Reader
- data := "Don't panic"
- r.Reset(data)
- nep.Write(&r, tcpip.WriteOptions{})
- b = c.GetPacket()
- tcp = header.IPv4(b).Payload()
- if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
- }
-}
-
-func TestV4AcceptOnV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4Accept(t, c)
-}
-
-func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to v4 mapped wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4Accept(t, c)
-}
-
-func TestV4AcceptOnBoundToV4Mapped(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind and listen.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4Accept(t, c)
-}
-
-func TestV6AcceptOnV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind and listen.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetV6Packet()
- tcp := header.TCP(header.IPv6(b).Payload())
- iss := seqnum.Value(tcp.SequenceNumber())
- checker.IPv6(t, b,
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs)+1),
- ),
- )
-
- // Send ACK.
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
- var addr tcpip.FullAddress
- _, _, err := c.EP.Accept(&addr)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept(&addr)
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- if addr.Addr != context.TestV6Addr {
- t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr)
- }
-}
-
-func TestV4AcceptOnV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4Accept(t, c)
-}
-
-func testV4ListenClose(t *testing.T, c *context.Context) {
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
- const n = 32
-
- // Start listening.
- if err := c.EP.Listen(n); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- irs := seqnum.Value(789)
- for i := uint16(0); i < n; i++ {
- // Send a SYN request.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + i,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
- }
-
- // Each of these ACKs will cause a syn-cookie based connection to be
- // accepted and delivered to the listening endpoint.
- for i := 0; i < n; i++ {
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- iss := seqnum.Value(tcp.SequenceNumber())
- // Send ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
- }
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
- nep, _, err := c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- nep, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(10 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
- nep.Close()
- c.EP.Close()
-}
-
-func TestV4ListenCloseOnV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4ListenClose(t, c)
-}
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
deleted file mode 100644
index e47a07030..000000000
--- a/pkg/tcpip/transport/tcp/rcv_test.go
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
-)
-
-func TestAcceptable(t *testing.T) {
- for _, tt := range []struct {
- segSeq seqnum.Value
- segLen seqnum.Size
- rcvNxt, rcvAcc seqnum.Value
- want bool
- }{
- // The segment is smaller than the window.
- {105, 2, 100, 104, false},
- {105, 2, 101, 105, true},
- {105, 2, 102, 106, true},
- {105, 2, 103, 107, true},
- {105, 2, 104, 108, true},
- {105, 2, 105, 109, true},
- {105, 2, 106, 110, true},
- {105, 2, 107, 111, false},
-
- // The segment is larger than the window.
- {105, 4, 103, 105, true},
- {105, 4, 104, 106, true},
- {105, 4, 105, 107, true},
- {105, 4, 106, 108, true},
- {105, 4, 107, 109, true},
- {105, 4, 108, 110, true},
- {105, 4, 109, 111, false},
- {105, 4, 110, 112, false},
-
- // The segment has no width.
- {105, 0, 100, 102, false},
- {105, 0, 101, 103, false},
- {105, 0, 102, 104, false},
- {105, 0, 103, 105, true},
- {105, 0, 104, 106, true},
- {105, 0, 105, 107, true},
- {105, 0, 106, 108, false},
- {105, 0, 107, 109, false},
-
- // The receive window has no width.
- {105, 2, 103, 103, false},
- {105, 2, 104, 104, false},
- {105, 2, 105, 105, false},
- {105, 2, 106, 106, false},
- {105, 2, 107, 107, false},
- {105, 2, 108, 108, false},
- {105, 2, 109, 109, false},
- } {
- if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want {
- t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want)
- }
- }
-}
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
deleted file mode 100644
index b4e5ba0df..000000000
--- a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
+++ /dev/null
@@ -1,249 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
-)
-
-const smss = 1500
-
-func initScoreboard(blocks []header.SACKBlock, iss seqnum.Value) *tcp.SACKScoreboard {
- s := tcp.NewSACKScoreboard(smss, iss)
- for _, blk := range blocks {
- s.Insert(blk)
- }
- return s
-}
-
-func TestSACKScoreboardIsSACKED(t *testing.T) {
- type blockTest struct {
- block header.SACKBlock
- sacked bool
- }
- testCases := []struct {
- comment string
- scoreboardBlocks []header.SACKBlock
- blockTests []blockTest
- iss seqnum.Value
- }{
- {
- "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping SACK blocks",
- []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}},
- []blockTest{
- {header.SACKBlock{15, 21}, true},
- {header.SACKBlock{200, 201}, false},
- {header.SACKBlock{50, 51}, false},
- {header.SACKBlock{53, 120}, true},
- },
- 0,
- },
- {
- "Test disjoint SACKBlocks",
- []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}},
- []blockTest{
- {header.SACKBlock{2288624809, 2288810057}, true},
- {header.SACKBlock{2288811477, 2288838565}, true},
- {header.SACKBlock{2288810057, 2288811477}, false},
- },
- 2288624809,
- },
- {
- "Test sequence number wrap around",
- []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}},
- []blockTest{
- {header.SACKBlock{4294254144, 4294254145}, true},
- {header.SACKBlock{4294254143, 4294254144}, false},
- {header.SACKBlock{4294254144, 1}, true},
- {header.SACKBlock{225652, 5350509}, false},
- {header.SACKBlock{5340409, 5350509}, true},
- {header.SACKBlock{5350509, 5350609}, false},
- },
- 4294254144,
- },
- {
- "Test disjoint SACKBlocks out of order",
- []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}},
- []blockTest{
- {header.SACKBlock{827426028, 827428867}, true},
- {header.SACKBlock{827450168, 827450275}, false},
- },
- 827426000,
- },
- }
- for _, tc := range testCases {
- sb := initScoreboard(tc.scoreboardBlocks, tc.iss)
- for _, blkTest := range tc.blockTests {
- if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want {
- t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want)
- }
- }
- }
-}
-
-func TestSACKScoreboardIsRangeLost(t *testing.T) {
- s := tcp.NewSACKScoreboard(10, 0)
- s.Insert(header.SACKBlock{1, 25})
- s.Insert(header.SACKBlock{25, 50})
- s.Insert(header.SACKBlock{51, 100})
- s.Insert(header.SACKBlock{111, 120})
- s.Insert(header.SACKBlock{101, 110})
- s.Insert(header.SACKBlock{121, 141})
- s.Insert(header.SACKBlock{145, 146})
- s.Insert(header.SACKBlock{147, 148})
- s.Insert(header.SACKBlock{149, 150})
- s.Insert(header.SACKBlock{153, 154})
- s.Insert(header.SACKBlock{155, 156})
- testCases := []struct {
- block header.SACKBlock
- lost bool
- }{
- // Block not covered by SACK block and has more than
- // nDupAckThreshold discontiguous SACK blocks after it as well
- // as (nDupAckThreshold -1) * 10 (smss) bytes that have been
- // SACKED above the sequence number covered by this block.
- {block: header.SACKBlock{0, 1}, lost: true},
-
- // These blocks have all been SACKed and should not be
- // considered lost.
- {block: header.SACKBlock{1, 2}, lost: false},
- {block: header.SACKBlock{25, 26}, lost: false},
- {block: header.SACKBlock{1, 45}, lost: false},
-
- // Same as the first case above.
- {block: header.SACKBlock{50, 51}, lost: true},
-
- // This block has been SACKed and should not be considered lost.
- {block: header.SACKBlock{119, 120}, lost: false},
-
- // This one should return true because there are >
- // (nDupAckThreshold - 1) * 10 (smss) bytes that have been
- // sacked above this sequence number.
- {block: header.SACKBlock{120, 121}, lost: true},
-
- // This block has been SACKed and should not be considered lost.
- {block: header.SACKBlock{125, 126}, lost: false},
-
- // This block has not been SACKed and there are nDupAckThreshold
- // number of SACKed blocks after it.
- {block: header.SACKBlock{141, 145}, lost: true},
-
- // This block has not been SACKed and there are less than
- // nDupAckThreshold SACKed sequences after it.
- {block: header.SACKBlock{151, 152}, lost: false},
- }
- for _, tc := range testCases {
- if want, got := tc.lost, s.IsRangeLost(tc.block); got != want {
- t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want)
- }
- }
-}
-
-func TestSACKScoreboardIsLost(t *testing.T) {
- s := tcp.NewSACKScoreboard(10, 0)
- s.Insert(header.SACKBlock{1, 25})
- s.Insert(header.SACKBlock{25, 50})
- s.Insert(header.SACKBlock{51, 100})
- s.Insert(header.SACKBlock{111, 120})
- s.Insert(header.SACKBlock{101, 110})
- s.Insert(header.SACKBlock{121, 141})
- s.Insert(header.SACKBlock{121, 141})
- s.Insert(header.SACKBlock{145, 146})
- s.Insert(header.SACKBlock{147, 148})
- s.Insert(header.SACKBlock{149, 150})
- s.Insert(header.SACKBlock{153, 154})
- s.Insert(header.SACKBlock{155, 156})
- testCases := []struct {
- seq seqnum.Value
- lost bool
- }{
- // Sequence number not covered by SACK block and has more than
- // nDupAckThreshold discontiguous SACK blocks after it as well
- // as (nDupAckThreshold -1) * 10 (smss) bytes that have been
- // SACKED above the sequence number.
- {seq: 0, lost: true},
-
- // These sequence numbers have all been SACKed and should not be
- // considered lost.
- {seq: 1, lost: false},
- {seq: 25, lost: false},
- {seq: 45, lost: false},
-
- // Same as first case above.
- {seq: 50, lost: true},
-
- // This block has been SACKed and should not be considered lost.
- {seq: 119, lost: false},
-
- // This one should return true because there are >
- // (nDupAckThreshold - 1) * 10 (smss) bytes that have been
- // sacked above this sequence number.
- {seq: 120, lost: true},
-
- // This sequence number has been SACKed and should not be
- // considered lost.
- {seq: 125, lost: false},
-
- // This sequence number has not been SACKed and there are
- // nDupAckThreshold number of SACKed blocks after it.
- {seq: 141, lost: true},
-
- // This sequence number has not been SACKed and there are less
- // than nDupAckThreshold SACKed sequences after it.
- {seq: 151, lost: false},
- }
- for _, tc := range testCases {
- if want, got := tc.lost, s.IsLost(tc.seq); got != want {
- t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want)
- }
- }
-}
-
-func TestSACKScoreboardDelete(t *testing.T) {
- blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}
- s := initScoreboard(blocks, 4294254143)
- s.Delete(5340408)
- if s.Empty() {
- t.Fatalf("s.Empty() = true, want false")
- }
- if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want {
- t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want)
- }
- s.Delete(5340410)
- if s.Empty() {
- t.Fatal("s.Empty() = true, want false")
- }
- newSB := header.SACKBlock{5340410, 5350509}
- if !s.IsSACKED(newSB) {
- t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s)
- }
- s.Delete(5350509)
- lastOctet := header.SACKBlock{5350508, 5350509}
- if s.IsSACKED(lastOctet) {
- t.Fatalf("s.IsSACKED(%v) = false, want true", lastOctet)
- }
-
- s.Delete(5350510)
- if !s.Empty() {
- t.Fatal("s.Empty() = false, want true")
- }
- if got, want := s.Sacked(), seqnum.Size(0); got != want {
- t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want)
- }
-}
diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go
deleted file mode 100644
index 2d5fdda19..000000000
--- a/pkg/tcpip/transport/tcp/segment_test.go
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp
-
-import (
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type segmentSizeWants struct {
- DataSize int
- SegMemSize int
-}
-
-func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeWants) {
- t.Helper()
- got := segmentSizeWants{
- DataSize: seg.data.Size(),
- SegMemSize: seg.segMemSize(),
- }
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("%s differs (-want +got):\n%s", name, diff)
- }
-}
-
-func TestSegmentMerge(t *testing.T) {
- var clock faketime.NullClock
- id := stack.TransportEndpointID{}
- seg1 := newOutgoingSegment(id, &clock, buffer.NewView(10))
- defer seg1.decRef()
- seg2 := newOutgoingSegment(id, &clock, buffer.NewView(20))
- defer seg2.decRef()
-
- checkSegmentSize(t, "seg1", seg1, segmentSizeWants{
- DataSize: 10,
- SegMemSize: SegSize + 10,
- })
- checkSegmentSize(t, "seg2", seg2, segmentSizeWants{
- DataSize: 20,
- SegMemSize: SegSize + 20,
- })
-
- seg1.merge(seg2)
-
- checkSegmentSize(t, "seg1", seg1, segmentSizeWants{
- DataSize: 30,
- SegMemSize: SegSize + 30,
- })
- checkSegmentSize(t, "seg2", seg2, segmentSizeWants{
- DataSize: 0,
- SegMemSize: SegSize,
- })
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_endpoint_list.go b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go
new file mode 100644
index 000000000..a7dc5df81
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go
@@ -0,0 +1,221 @@
+package tcp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type endpointElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (endpointElementMapper) linkerFor(elem *endpoint) *endpoint { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type endpointList struct {
+ head *endpoint
+ tail *endpoint
+}
+
+// Reset resets list l to the empty state.
+func (l *endpointList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *endpointList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *endpointList) Front() *endpoint {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *endpointList) Back() *endpoint {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *endpointList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (endpointElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *endpointList) PushFront(e *endpoint) {
+ linker := endpointElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ endpointElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *endpointList) PushBack(e *endpoint) {
+ linker := endpointElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ endpointElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *endpointList) PushBackList(m *endpointList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ endpointElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ endpointElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *endpointList) InsertAfter(b, e *endpoint) {
+ bLinker := endpointElementMapper{}.linkerFor(b)
+ eLinker := endpointElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ endpointElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *endpointList) InsertBefore(a, e *endpoint) {
+ aLinker := endpointElementMapper{}.linkerFor(a)
+ eLinker := endpointElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ endpointElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *endpointList) Remove(e *endpoint) {
+ linker := endpointElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ endpointElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ endpointElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type endpointEntry struct {
+ next *endpoint
+ prev *endpoint
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *endpointEntry) Next() *endpoint {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *endpointEntry) Prev() *endpoint {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *endpointEntry) SetNext(elem *endpoint) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *endpointEntry) SetPrev(elem *endpoint) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
deleted file mode 100644
index 84fb1c416..000000000
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ /dev/null
@@ -1,559 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-// These tests are flaky when run under the go race detector due to some
-// iterations taking long enough that the retransmit timer can kick in causing
-// the congestion window measurements to fail due to extra packets etc.
-//
-//go:build !race
-// +build !race
-
-package tcp_test
-
-import (
- "bytes"
- "fmt"
- "math"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/test/testutil"
-)
-
-func TestFastRecovery(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 3
- data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
- }
-
- // Send 3 duplicate acks. This should force an immediate retransmit of
- // the pending packet and put the sender into fast recovery.
- rtxOffset := bytesRead - maxPayload*expected
- for i := 0; i < 3; i++ {
- c.SendAck(790, rtxOffset)
- }
-
- // Receive the retransmitted packet.
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- // Wait before checking metrics.
- metricPollFn := func() error {
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
- }
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want)
- }
- return nil
- }
-
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-
- // Now send 7 mode duplicate acks. Each of these should cause a window
- // inflation by 1 and cause the sender to send an extra packet.
- for i := 0; i < 7; i++ {
- c.SendAck(790, rtxOffset)
- }
-
- recover := bytesRead
-
- // Ensure no new packets arrive.
- c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
- 50*time.Millisecond)
-
- // Acknowledge half of the pending data.
- rtxOffset = bytesRead - expected*maxPayload/2
- c.SendAck(790, rtxOffset)
-
- // Receive the retransmit due to partial ack.
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- // Wait before checking metrics.
- metricPollFn = func() error {
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
- }
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
- return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-
- // Receive the 10 extra packets that should have been released due to
- // the congestion window inflation in recovery.
- for i := 0; i < 10; i++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // A partial ACK during recovery should reduce congestion window by the
- // number acked. Since we had "expected" packets outstanding before sending
- // partial ack and we acked expected/2 , the cwnd and outstanding should
- // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered
- // fast recovery). Which means the sender should not send any more packets
- // till we ack this one.
- c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.",
- 50*time.Millisecond)
-
- // Acknowledge all pending data to recover point.
- c.SendAck(790, recover)
-
- // At this point, the cwnd should reset to expected/2 and there are 10
- // packets outstanding.
- //
- // NOTE: Technically netstack is incorrect in that we adjust the cwnd on
- // the same segment that takes us out of recovery. But because of that
- // the actual cwnd at exit of recovery will be expected/2 + 1 as we
- // acked a cwnd worth of packets which will increase the cwnd further by
- // 1 in congestion avoidance.
- //
- // Now in the first iteration since there are 10 packets outstanding.
- // We would expect to get expected/2 +1 - 10 packets. But subsequent
- // iterations will send us expected/2 + 1 + 1 (per iteration).
- expected = expected/2 + 1 - 10
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // In cogestion avoidance, the packets trains increase by 1 in
- // each iteration.
- if i == 0 {
- // After the first iteration we expect to get the full
- // congestion window worth of packets in every
- // iteration.
- expected += 10
- }
- expected++
- }
-}
-
-func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 3
- data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // Double the number of expected packets for the next iteration.
- expected *= 2
- }
-}
-
-func TestCongestionAvoidance(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 3
- data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond)
- }
-
- // Don't acknowledge the first packet of the last packet train. Let's
- // wait for them to time out, which will trigger a restart of slow
- // start, and initialization of ssthresh to cwnd/2.
- rtxOffset := bytesRead - maxPayload*expected
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // This part is tricky: when the timeout happened, we had "expected"
- // packets pending, cwnd reset to 1, and ssthresh set to expected/2.
- // By acknowledging "expected" packets, the slow-start part will
- // increase cwnd to expected/2 (which "consumes" expected/2-1 of the
- // acknowledgements), then the congestion avoidance part will consume
- // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack
- // remains in the "ack count" (which will cause cwnd to be incremented
- // once it reaches cwnd acks).
- //
- // So we're straight into congestion avoidance with cwnd set to
- // expected/2 + 1.
- //
- // Check that packets trains of cwnd packets are sent, and that cwnd is
- // incremented by 1 after we acknowledge each packet.
- expected = expected/2 + 1
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // In cogestion avoidance, the packets trains increase by 1 in
- // each iteration.
- expected++
- }
-}
-
-// cubicCwnd returns an estimate of a cubic window given the
-// originalCwnd, wMax, last congestion event time and sRTT.
-func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int {
- cwnd := float64(origCwnd)
- // We wait 50ms between each iteration so sRTT as computed by cubic
- // should be close to 50ms.
- elapsed := (time.Since(congEventTime) + sRTT).Seconds()
- k := math.Cbrt(float64(wMax) * 0.3 / 0.7)
- wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax)
- cwnd += (wtRTT - cwnd) / cwnd
- return int(cwnd)
-}
-
-func TestCubicCongestionAvoidance(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- enableCUBIC(t, c)
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 3
- data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond)
- }
-
- // Don't acknowledge the first packet of the last packet train. Let's
- // wait for them to time out, which will trigger a restart of slow
- // start, and initialization of ssthresh to cwnd * 0.7.
- rtxOffset := bytesRead - maxPayload*expected
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- // Acknowledge all pending data.
- c.SendAck(790, bytesRead)
-
- // Store away the time we sent the ACK and assuming a 200ms RTO
- // we estimate that the sender will have an RTO 200ms from now
- // and go back into slow start.
- packetDropTime := time.Now().Add(200 * time.Millisecond)
-
- // This part is tricky: when the timeout happened, we had "expected"
- // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7.
- // By acknowledging "expected" packets, the slow-start part will
- // increase cwnd to expected/2 essentially putting the connection
- // straight into congestion avoidance.
- wMax := expected
- // Lower expected as per cubic spec after a congestion event.
- expected = int(float64(expected) * 0.7)
- cwnd := expected
- for i := 0; i < iterations; i++ {
- // Cubic grows window independent of ACKs. Cubic Window growth
- // is a function of time elapsed since last congestion event.
- // As a result the congestion window does not grow
- // deterministically in response to ACKs.
- //
- // We need to roughly estimate what the cwnd of the sender is
- // based on when we sent the dupacks.
- cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond)
-
- packetsExpected := cwnd
- for j := 0; j < packetsExpected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
- t.Logf("expected packets received, next trying to receive any extra packets that may come")
-
- // If our estimate was correct there should be no more pending packets.
- // We attempt to read a packet a few times with a short sleep in between
- // to ensure that we don't see the sender send any unexpected packets.
- unexpectedPackets := 0
- for {
- gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload)
- if !gotPacket {
- break
- }
- bytesRead += maxPayload
- unexpectedPackets++
- time.Sleep(1 * time.Millisecond)
- }
- if unexpectedPackets != 0 {
- t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i)
- }
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
- }
-}
-
-func TestRetransmit(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 3
- data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in two shots. Packets will only be written at the
- // MTU size though.
- var r bytes.Reader
- r.Reset(data[:len(data)/2])
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- r.Reset(data[len(data)/2:])
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
- }
-
- // Wait for a timeout and retransmit.
- rtxOffset := bytesRead - maxPayload*expected
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- metricPollFn := func() error {
- if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want)
- }
-
- return nil
- }
-
- // Poll when checking metrics.
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-
- // Acknowledge half of the pending data.
- rtxOffset = bytesRead - expected*maxPayload/2
- c.SendAck(790, rtxOffset)
-
- // Receive the remaining data, making sure that acknowledged data is not
- // retransmitted.
- for offset := rtxOffset; offset < len(data); offset += maxPayload {
- c.ReceiveAndCheckPacket(data, offset, maxPayload)
- c.SendAck(790, offset+maxPayload)
- }
-
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
deleted file mode 100644
index c35db7c95..000000000
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ /dev/null
@@ -1,1101 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "bytes"
- "fmt"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/test/testutil"
-)
-
-const (
- maxPayload = 10
- tsOptionSize = 12
- maxTCPOptionSize = 40
- mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
-)
-
-func setStackTCPRecovery(t *testing.T, c *context.Context, recovery int) {
- t.Helper()
- opt := tcpip.TCPRecovery(recovery)
- if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err)
- }
-}
-
-// TestRACKUpdate tests the RACK related fields are updated when an ACK is
-// received on a SACK enabled connection.
-func TestRACKUpdate(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- var xmitTime tcpip.MonotonicTime
- probeDone := make(chan struct{})
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that the endpoint Sender.RACKState is what we expect.
- if state.Sender.RACKState.XmitTime.Before(xmitTime) {
- t.Fatalf("RACK transmit time failed to update when an ACK is received")
- }
-
- gotSeq := state.Sender.RACKState.EndSequence
- wantSeq := state.Sender.SndNxt
- if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) {
- t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq)
- }
-
- if state.Sender.RACKState.RTT == 0 {
- t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0")
- }
- close(probeDone)
- })
- setStackSACKPermitted(t, c, true)
- createConnectedWithSACKAndTS(c)
-
- data := make([]byte, maxPayload)
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write the data.
- xmitTime = c.Stack().Clock().NowMonotonic()
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- bytesRead := 0
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- <-probeDone
-}
-
-// TestRACKDetectReorder tests that RACK detects packet reordering.
-func TestRACKDetectReorder(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- t.Skipf("Skipping this test as reorder detection does not consider DSACK.")
-
- var n int
- const ackNumToVerify = 2
- probeDone := make(chan struct{})
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- gotSeq := state.Sender.RACKState.FACK
- wantSeq := state.Sender.SndNxt
- // FACK should be updated to the highest ending sequence number of the
- // segment acknowledged most recently.
- if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) {
- t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq)
- }
-
- n++
- if n < ackNumToVerify {
- if state.Sender.RACKState.Reord {
- t.Fatalf("RACK reorder detected when there is no reordering")
- }
- return
- }
-
- if state.Sender.RACKState.Reord == false {
- t.Fatalf("RACK reorder detection failed")
- }
- close(probeDone)
- })
- setStackSACKPermitted(t, c, true)
- createConnectedWithSACKAndTS(c)
- data := make([]byte, ackNumToVerify*maxPayload)
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write the data.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- bytesRead := 0
- for i := 0; i < ackNumToVerify; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
-
- start := c.IRS.Add(maxPayload + 1)
- end := start.Add(maxPayload)
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
- c.SendAck(seq, bytesRead)
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- <-probeDone
-}
-
-func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, enableRACK bool) []byte {
- setStackSACKPermitted(t, c, true)
- if !enableRACK {
- setStackTCPRecovery(t, c, 0)
- }
- // The delay should be below initial RTO (1s) otherwise retransimission
- // will start. Choose a relatively large value so that estimated RTT
- // keeps high even after a few rounds of undelayed RTT samples.
- c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}, 800*time.Millisecond /* delay */)
-
- data := make([]byte, numPackets*maxPayload)
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write the data.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- bytesRead := 0
- for i := 0; i < numPackets; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
-
- return data
-}
-
-const (
- validDSACKDetected = 1
- failedToDetectDSACK = 2
- invalidDSACKDetected = 3
-)
-
-func addDSACKSeenCheckerProbe(t *testing.T, c *context.Context, numACK int, probeDone chan int) {
- var n int
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that RACK detects DSACK.
- n++
- if n < numACK {
- if state.Sender.RACKState.DSACKSeen {
- probeDone <- invalidDSACKDetected
- }
- return
- }
-
- if !state.Sender.RACKState.DSACKSeen {
- probeDone <- failedToDetectDSACK
- return
- }
- probeDone <- validDSACKDetected
- })
-}
-
-// TestRACKTLPRecovery tests that RACK sends a tail loss probe (TLP) in the
-// case of a tail loss. This simulates a situation where the TLP is able to
-// insinuate the SACK holes and sender is able to retransmit the rest.
-func TestRACKTLPRecovery(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- // Send 8 packets.
- numPackets := 8
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Packets [6-8] are lost. Send cumulative ACK for [1-5].
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // PTO should fire and send #8 packet as a TLP.
- c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize)
- var info tcpip.TCPInfoOption
- if err := c.EP.GetSockOpt(&info); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
-
- // Send the SACK after RTT because RACK RFC states that if the ACK for a
- // retransmission arrives before the smoothed RTT then the sender should not
- // update RACK state as it could be a spurious inference.
- time.Sleep(info.RTT)
-
- // Okay, let the sender know we got #8 using a SACK block.
- eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload))
- eighthPEnd := eighthPStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}})
-
- // The sender should be entering RACK based loss-recovery and sending #6 and
- // #7 one after another.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += 2 * maxPayload
- c.SendAck(seq, bytesRead)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // One fast retransmit after the SACK.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- // Recovery should be SACK recovery.
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- // Packets 6, 7 and 8 were retransmitted.
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 3},
- // TLP recovery should have been detected.
- {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 1},
- // No RTOs should have occurred.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestRACKTLPFallbackRTO tests that RACK sends a tail loss probe (TLP) in the
-// case of a tail loss. This simulates a situation where either the TLP or its
-// ACK is lost. The sender should retransmit when RTO fires.
-func TestRACKTLPFallbackRTO(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- // Send 8 packets.
- numPackets := 8
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Packets [6-8] are lost. Send cumulative ACK for [1-5].
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // PTO should fire and send #8 packet as a TLP.
- c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize)
-
- // Either the TLP or the ACK the receiver sent with SACK blocks was lost.
-
- // Confirm that RTO fires and retransmits packet #6.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // No fast retransmits happened.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0},
- // No SACK recovery happened.
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0},
- // TLP was unsuccessful.
- {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0},
- // RTO should have fired.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestNoTLPRecoveryOnDSACK tests the scenario where the sender speculates a
-// tail loss and sends a TLP. Everything is received and acked. The probe
-// segment is DSACKed. No fast recovery should be triggered in this case.
-func TestNoTLPRecoveryOnDSACK(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- // Send 8 packets.
- numPackets := 8
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Packets [1-5] are received first. [6-8] took a detour and will take a
- // while to arrive. Ack [1-5].
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // The tail loss probe (#8 packet) is received.
- c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize)
-
- // Now that all 8 packets are received + duplicate 8th packet, send ack.
- bytesRead += 3 * maxPayload
- eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload))
- eighthPEnd := eighthPStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}})
-
- // Wait for RTO and make sure that nothing else is received.
- var info tcpip.TCPInfoOption
- if err := c.EP.GetSockOpt(&info); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
- if p := c.GetPacketWithTimeout(info.RTO); p != nil {
- t.Errorf("received an unexpected packet: %v", p)
- }
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // Make sure no recovery was entered.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0},
- {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0},
- // RTO should not have fired.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
- // Only #8 was retransmitted.
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestNoTLPOnSACK tests the scenario where there is not exactly a tail loss
-// due to the presence of multiple SACK holes. In such a scenario, TLP should
-// not be sent.
-func TestNoTLPOnSACK(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- // Send 8 packets.
- numPackets := 8
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Packets [1-5] and #7 were received. #6 and #8 were dropped.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- seventhStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
- seventhEnd := seventhStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhStart, seventhEnd}})
-
- // The sender should retransmit #6. If the sender sends a TLP, then #8 will
- // received and fail this test.
- c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // #6 was retransmitted due to SACK recovery.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0},
- // RTO should not have fired.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
- // Only #6 was retransmitted.
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestRACKOnePacketTailLoss tests the trivial case of a tail loss of only one
-// packet. The probe should itself repairs the loss instead of having to go
-// into any recovery.
-func TestRACKOnePacketTailLoss(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- // Send 3 packets.
- numPackets := 3
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Packets [1-2] are received. #3 is lost.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 2 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // PTO should fire and send #3 packet as a TLP.
- c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- c.SendAck(seq, bytesRead)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // #3 was retransmitted as TLP.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0},
- // RTO should not have fired.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
- // Only #3 was retransmitted.
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestRACKDetectDSACK tests that RACK detects DSACK with duplicate segments.
-// See: https://tools.ietf.org/html/rfc2883#section-4.1.1.
-func TestRACKDetectDSACK(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan int)
- const ackNumToVerify = 2
- addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
-
- numPackets := 8
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Cumulative ACK for [1-5] packets and SACK #8 packet (to prevent TLP).
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload))
- eighthPEnd := eighthPStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}})
-
- // Expect retransmission of #6 packet after RTO expires.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- // Send DSACK block for #6 packet indicating both
- // initial and retransmitted packet are received and
- // packets [1-8] are received.
- start := c.IRS.Add(1 + seqnum.Size(bytesRead))
- end := start.Add(maxPayload)
- bytesRead += 3 * maxPayload
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Wait for the probe function to finish processing the
- // ACK before the test completes.
- err := <-probeDone
- switch err {
- case failedToDetectDSACK:
- t.Fatalf("RACK DSACK detection failed")
- case invalidDSACKDetected:
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // Check DSACK was received for one segment.
- {tcpStats.SegmentsAckedWithDSACK, "stats.TCP.SegmentsAckedWithDSACK", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
-
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestRACKDetectDSACKWithOutOfOrder tests that RACK detects DSACK with out of
-// order segments.
-// See: https://tools.ietf.org/html/rfc2883#section-4.1.2.
-func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan int)
- const ackNumToVerify = 2
- addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
-
- numPackets := 10
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP).
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
- seventhPEnd := seventhPStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}})
-
- // Expect retransmission of #6 packet.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- // Send DSACK block for #6 packet indicating both
- // initial and retransmitted packet are received and
- // packets [1-7] are received.
- start := c.IRS.Add(1 + seqnum.Size(bytesRead))
- end := start.Add(maxPayload)
- bytesRead += 2 * maxPayload
- // Send DSACK block for #6 along with SACK for out of
- // order #9 packet.
- start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload)
- end1 := start1.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}, {start1, end1}})
-
- // Wait for the probe function to finish processing the
- // ACK before the test completes.
- err := <-probeDone
- switch err {
- case failedToDetectDSACK:
- t.Fatalf("RACK DSACK detection failed")
- case invalidDSACKDetected:
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-}
-
-// TestRACKDetectDSACKWithOutOfOrderDup tests that DSACK is detected on a
-// duplicate of out of order packet.
-// See: https://tools.ietf.org/html/rfc2883#section-4.1.3
-func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan int)
- const ackNumToVerify = 4
- addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
-
- numPackets := 10
- sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // ACK [1-5] packets.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // Send SACK indicating #6 packet is missing and received #7 packet.
- offset := seqnum.Size(bytesRead + maxPayload)
- start := c.IRS.Add(1 + offset)
- end := start.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Send SACK with #6 packet is missing and received [7-8] packets.
- end = start.Add(2 * maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Consider #8 packet is duplicated on the network and send DSACK.
- dsackStart := c.IRS.Add(1 + offset + maxPayload)
- dsackEnd := dsackStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- err := <-probeDone
- switch err {
- case failedToDetectDSACK:
- t.Fatalf("RACK DSACK detection failed")
- case invalidDSACKDetected:
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-}
-
-// TestRACKDetectDSACKSingleDup tests DSACK for a single duplicate subsegment.
-// See: https://tools.ietf.org/html/rfc2883#section-4.2.1.
-func TestRACKDetectDSACKSingleDup(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan int)
- const ackNumToVerify = 4
- addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
-
- numPackets := 4
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Send ACK for #1 packet.
- bytesRead := maxPayload
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAck(seq, bytesRead)
-
- // Missing [2-3] packets and received #4 packet.
- seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- start := c.IRS.Add(1 + seqnum.Size(3*maxPayload))
- end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Expect retransmission of #2 packet.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- // ACK for retransmitted #2 packet.
- bytesRead += maxPayload
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Simulate receving delayed subsegment of #2 packet and delayed #3 packet by
- // sending DSACK block for the subsegment.
- dsackStart := c.IRS.Add(1 + seqnum.Size(bytesRead))
- dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2))
- c.SendAckWithSACK(seq, numPackets*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}})
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- err := <-probeDone
- switch err {
- case failedToDetectDSACK:
- t.Fatalf("RACK DSACK detection failed")
- case invalidDSACKDetected:
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // Check DSACK was received for a subsegment.
- {tcpStats.SegmentsAckedWithDSACK, "stats.TCP.SegmentsAckedWithDSACK", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
-
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestRACKDetectDSACKDupWithCumulativeACK tests DSACK for two non-contiguous
-// duplicate subsegments covered by the cumulative acknowledgement.
-// See: https://tools.ietf.org/html/rfc2883#section-4.2.2.
-func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan int)
- const ackNumToVerify = 5
- addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
-
- numPackets := 6
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Send ACK for #1 packet.
- bytesRead := maxPayload
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAck(seq, bytesRead)
-
- // Missing [2-5] packets and received #6 packet.
- seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- start := c.IRS.Add(1 + seqnum.Size(5*maxPayload))
- end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Expect retransmission of #2 packet.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- // Received delayed #2 packet.
- bytesRead += maxPayload
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Received delayed #4 packet.
- start1 := c.IRS.Add(1 + seqnum.Size(3*maxPayload))
- end1 := start1.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}})
-
- // Simulate receiving retransmitted subsegment for #2 packet and delayed #3
- // packet by sending DSACK block for #2 packet.
- dsackStart := c.IRS.Add(1 + seqnum.Size(maxPayload))
- dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2))
- c.SendAckWithSACK(seq, 4*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- err := <-probeDone
- switch err {
- case failedToDetectDSACK:
- t.Fatalf("RACK DSACK detection failed")
- case invalidDSACKDetected:
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-}
-
-// TestRACKDetectDSACKDup tests two non-contiguous duplicate subsegments not
-// covered by the cumulative acknowledgement.
-// See: https://tools.ietf.org/html/rfc2883#section-4.2.3.
-func TestRACKDetectDSACKDup(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan int)
- const ackNumToVerify = 5
- addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
-
- numPackets := 7
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Send ACK for #1 packet.
- bytesRead := maxPayload
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAck(seq, bytesRead)
-
- // Missing [2-6] packets and SACK #7 packet.
- seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- start := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
- end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Received delayed #3 packet.
- start1 := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
- end1 := start1.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}})
-
- // Expect retransmission of #2 packet.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- // Consider #2 packet has been dropped and SACK #4 packet.
- start2 := c.IRS.Add(1 + seqnum.Size(3*maxPayload))
- end2 := start2.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start2, end2}, {start1, end1}, {start, end}})
-
- // Simulate receiving retransmitted subsegment for #3 packet and delayed #5
- // packet by sending DSACK block for the subsegment.
- dsackStart := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
- dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2))
- end1 = end1.Add(seqnum.Size(2 * maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start1, end1}})
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- err := <-probeDone
- switch err {
- case failedToDetectDSACK:
- t.Fatalf("RACK DSACK detection failed")
- case invalidDSACKDetected:
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-}
-
-// TestRACKWithInvalidDSACKBlock tests that DSACK is not detected when DSACK
-// is not the first SACK block.
-func TestRACKWithInvalidDSACKBlock(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan struct{})
- const ackNumToVerify = 2
- var n int
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that RACK does not detect DSACK when DSACK block is
- // not the first SACK block.
- n++
- t.Helper()
- if state.Sender.RACKState.DSACKSeen {
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-
- if n == ackNumToVerify {
- close(probeDone)
- }
- })
-
- numPackets := 10
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP).
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
- seventhPEnd := seventhPStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}})
-
- // Expect retransmission of #6 packet.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- // Send DSACK block for #6 packet indicating both
- // initial and retransmitted packet are received and
- // packets [1-7] are received.
- start := c.IRS.Add(1 + seqnum.Size(bytesRead))
- end := start.Add(maxPayload)
- bytesRead += 2 * maxPayload
-
- // Send DSACK block as second block. The first block is a SACK for #9 packet.
- start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload)
- end1 := start1.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}})
-
- // Wait for the probe function to finish processing the
- // ACK before the test completes.
- <-probeDone
-}
-
-func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan error) {
- var n int
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that RACK detects DSACK.
- n++
- if n < numACK {
- return
- }
-
- if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.RTTState.SRTT {
- probeDone <- fmt.Errorf("got RACKState.ReoWnd: %d, expected it to be greater than 0 and less than %d", state.Sender.RACKState.ReoWnd, state.Sender.RTTState.SRTT)
- return
- }
-
- if state.Sender.RACKState.ReoWndIncr != 1 {
- probeDone <- fmt.Errorf("got RACKState.ReoWndIncr: %v, want: 1", state.Sender.RACKState.ReoWndIncr)
- return
- }
-
- if state.Sender.RACKState.ReoWndPersist > 0 {
- probeDone <- fmt.Errorf("got RACKState.ReoWndPersist: %v, want: greater than 0", state.Sender.RACKState.ReoWndPersist)
- return
- }
- probeDone <- nil
- })
-}
-
-func TestRACKCheckReorderWindow(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan error)
- const ackNumToVerify = 3
- addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone)
-
- const numPackets = 7
- sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Send ACK for #1 packet.
- bytesRead := maxPayload
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAck(seq, bytesRead)
-
- // Missing [2-6] packets and SACK #7 packet.
- start := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
- end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Received delayed packets [2-6] which indicates there is reordering
- // in the connection.
- bytesRead += 6 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- if err := <-probeDone; err != nil {
- t.Fatalf("unexpected values for RACK variables: %v", err)
- }
-}
-
-func TestRACKWithDuplicateACK(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- const numPackets = 4
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Send three duplicate ACKs to trigger fast recovery. The first
- // segment is considered as lost and will be retransmitted after
- // receiving the duplicate ACKs.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- start := c.IRS.Add(1 + seqnum.Size(maxPayload))
- end := start.Add(seqnum.Size(maxPayload))
- for i := 0; i < 3; i++ {
- c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
- end = end.Add(seqnum.Size(maxPayload))
- }
-
- // Receive the retransmitted packet.
- c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
-
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestRACKUpdateSackedOut tests the sacked out field is updated when a SACK
-// is received.
-func TestRACKUpdateSackedOut(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan struct{})
- ackNum := 0
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that the endpoint Sender.SackedOut is what we expect.
- if state.Sender.SackedOut != 2 && ackNum == 0 {
- t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut)
- }
-
- if state.Sender.SackedOut != 0 && ackNum == 1 {
- t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut)
- }
- if ackNum > 0 {
- close(probeDone)
- }
- ackNum++
- })
-
- sendAndReceiveWithSACK(t, c, 8, true /* enableRACK */)
-
- // ACK for [3-5] packets.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload))
- bytesRead := 2 * maxPayload
- end := start.Add(seqnum.Size(bytesRead))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- bytesRead += 3 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- <-probeDone
-}
-
-// TestRACKWithWindowFull tests that RACK honors the receive window size.
-func TestRACKWithWindowFull(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- setStackSACKPermitted(t, c, true)
- createConnectedWithSACKAndTS(c)
-
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- const numPkts = 10
- data := make([]byte, numPkts*maxPayload)
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write the data.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- bytesRead := 0
- for i := 0; i < numPkts; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- if i == 0 {
- // Send ACK for the first packet to establish RTT.
- c.SendAck(seq, maxPayload)
- }
- }
-
- // SACK for #10 packet.
- start := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload))
- end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{start, end}})
-
- var info tcpip.TCPInfoOption
- if err := c.EP.GetSockOpt(&info); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
- // Wait for RTT to trigger recovery.
- time.Sleep(info.RTT)
-
- // Expect retransmission of #2 packet.
- c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, tsOptionSize)
-
- // Send ACK for #2 packet.
- c.SendAck(seq, 3*maxPayload)
-
- // Expect retransmission of #3 packet.
- c.ReceiveAndCheckPacketWithOptions(data, 3*maxPayload, maxPayload, tsOptionSize)
-
- // Send ACK with zero window size.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seq,
- AckNum: c.IRS.Add(1 + 4*maxPayload),
- RcvWnd: 0,
- })
-
- // No packet should be received as the receive window size is zero.
- c.CheckNoPacket("unexpected packet received after userTimeout has expired")
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
deleted file mode 100644
index 6255355bb..000000000
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ /dev/null
@@ -1,704 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "bytes"
- "fmt"
- "log"
- "reflect"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/test/testutil"
-)
-
-// createConnectedWithSACKPermittedOption creates and connects c.ep with the
-// SACKPermitted option enabled if the stack in the context has the SACK support
-// enabled.
-func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
-}
-
-// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS
-// option enabled if the stack in the context has SACK and TS enabled.
-func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
-}
-
-func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
- t.Helper()
- opt := tcpip.TCPSACKEnabled(enable)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-}
-
-// TestSackPermittedConnect establishes a connection with the SACK option
-// enabled.
-func TestSackPermittedConnect(t *testing.T) {
- for _, sackEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- setStackSACKPermitted(t, c, sackEnabled)
- setStackTCPRecovery(t, c, 0)
- rep := createConnectedWithSACKPermittedOption(c)
- data := []byte{1, 2, 3}
-
- rep.SendPacket(data, nil)
- savedSeqNum := rep.NextSeqNum
- rep.VerifyACKNoSACK()
-
- // Make an out of order packet and send it.
- rep.NextSeqNum += 3
- sackBlocks := []header.SACKBlock{
- {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
- }
- rep.SendPacket(data, nil)
-
- // Restore the saved sequence number so that the
- // VerifyXXX calls use the right sequence number for
- // checking ACK numbers.
- rep.NextSeqNum = savedSeqNum
- if sackEnabled {
- rep.VerifyACKHasSACK(sackBlocks)
- } else {
- rep.VerifyACKNoSACK()
- }
-
- // Send the missing segment.
- rep.SendPacket(data, nil)
- // The ACK should contain the cumulative ACK for all 9
- // bytes sent and no SACK blocks.
- rep.NextSeqNum += 3
- // Check that no SACK block is returned in the ACK.
- rep.VerifyACKNoSACK()
- })
- }
-}
-
-// TestSackDisabledConnect establishes a connection with the SACK option
-// disabled and verifies that no SACKs are sent for out of order segments.
-func TestSackDisabledConnect(t *testing.T) {
- for _, sackEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- setStackSACKPermitted(t, c, sackEnabled)
- setStackTCPRecovery(t, c, 0)
-
- rep := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
-
- data := []byte{1, 2, 3}
-
- rep.SendPacket(data, nil)
- savedSeqNum := rep.NextSeqNum
- rep.VerifyACKNoSACK()
-
- // Make an out of order packet and send it.
- rep.NextSeqNum += 3
- rep.SendPacket(data, nil)
-
- // The ACK should contain the older sequence number and
- // no SACK blocks.
- rep.NextSeqNum = savedSeqNum
- rep.VerifyACKNoSACK()
-
- // Send the missing segment.
- rep.SendPacket(data, nil)
- // The ACK should contain the cumulative ACK for all 9
- // bytes sent and no SACK blocks.
- rep.NextSeqNum += 3
- // Check that no SACK block is returned in the ACK.
- rep.VerifyACKNoSACK()
- })
- }
-}
-
-// TestSackPermittedAccept accepts and establishes a connection with the
-// SACKPermitted option enabled if the connection request specifies the
-// SACKPermitted option. In case of SYN cookies SACK should be disabled as we
-// don't encode the SACK information in the cookie.
-func TestSackPermittedAccept(t *testing.T) {
- type testCase struct {
- cookieEnabled bool
- sackPermitted bool
- wndScale int
- wndSize uint16
- }
-
- testCases := []testCase{
- // When cookie is used window scaling is disabled.
- {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
- }
-
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
- for _, sackEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- if tc.cookieEnabled {
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
- setStackSACKPermitted(t, c, sackEnabled)
- setStackTCPRecovery(t, c, 0)
-
- rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
- // Now verify no SACK blocks are
- // received when sack is disabled.
- data := []byte{1, 2, 3}
- rep.SendPacket(data, nil)
- rep.VerifyACKNoSACK()
-
- savedSeqNum := rep.NextSeqNum
-
- // Make an out of order packet and send
- // it.
- rep.NextSeqNum += 3
- sackBlocks := []header.SACKBlock{
- {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
- }
- rep.SendPacket(data, nil)
-
- // The ACK should contain the older
- // sequence number.
- rep.NextSeqNum = savedSeqNum
- if sackEnabled && tc.sackPermitted {
- rep.VerifyACKHasSACK(sackBlocks)
- } else {
- rep.VerifyACKNoSACK()
- }
-
- // Send the missing segment.
- rep.SendPacket(data, nil)
- // The ACK should contain the cumulative
- // ACK for all 9 bytes sent and no SACK
- // blocks.
- rep.NextSeqNum += 3
- // Check that no SACK block is returned
- // in the ACK.
- rep.VerifyACKNoSACK()
- })
- }
- })
- }
-}
-
-// TestSackDisabledAccept accepts and establishes a connection with
-// the SACKPermitted option disabled and verifies that no SACKs are
-// sent for out of order packets.
-func TestSackDisabledAccept(t *testing.T) {
- type testCase struct {
- cookieEnabled bool
- wndScale int
- wndSize uint16
- }
-
- testCases := []testCase{
- // When cookie is used window scaling is disabled.
- {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
- }
-
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
- for _, sackEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- if tc.cookieEnabled {
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
-
- setStackSACKPermitted(t, c, sackEnabled)
- setStackTCPRecovery(t, c, 0)
-
- rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Now verify no SACK blocks are
- // received when sack is disabled.
- data := []byte{1, 2, 3}
- rep.SendPacket(data, nil)
- rep.VerifyACKNoSACK()
- savedSeqNum := rep.NextSeqNum
-
- // Make an out of order packet and send
- // it.
- rep.NextSeqNum += 3
- rep.SendPacket(data, nil)
-
- // The ACK should contain the older
- // sequence number and no SACK blocks.
- rep.NextSeqNum = savedSeqNum
- rep.VerifyACKNoSACK()
-
- // Send the missing segment.
- rep.SendPacket(data, nil)
- // The ACK should contain the cumulative
- // ACK for all 9 bytes sent and no SACK
- // blocks.
- rep.NextSeqNum += 3
- // Check that no SACK block is returned
- // in the ACK.
- rep.VerifyACKNoSACK()
- })
- }
- })
- }
-}
-
-func TestUpdateSACKBlocks(t *testing.T) {
- testCases := []struct {
- segStart seqnum.Value
- segEnd seqnum.Value
- rcvNxt seqnum.Value
- sackBlocks []header.SACKBlock
- updated []header.SACKBlock
- }{
- // Trivial cases where current SACK block list is empty and we
- // have an out of order delivery.
- {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}},
- {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}},
- {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}},
-
- // Cases where current SACK block list is not empty and we have
- // an out of order delivery. Tests that the updated SACK block
- // list has the first block as the one that contains the new
- // SACK block representing the segment that was just delivered.
- {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}},
- {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}},
- {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}},
-
- // Ensure that we only retain header.MaxSACKBlocks and drop the
- // oldest one if adding a new block exceeds
- // header.MaxSACKBlocks.
- {24, 30, 9,
- []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}},
- []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}},
-
- // Cases where segment extends an existing SACK block.
- {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}},
- {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
- {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
- {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}},
- {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}},
- {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}},
- {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}},
- {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
- {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
- {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}},
- {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}},
- {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}},
-
- // Cases where segment contains rcvNxt.
- {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}},
- }
-
- for _, tc := range testCases {
- var sack tcp.SACKInfo
- copy(sack.Blocks[:], tc.sackBlocks)
- sack.NumBlocks = len(tc.sackBlocks)
- tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt)
- if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) {
- t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want)
- }
-
- }
-}
-
-func TestTrimSackBlockList(t *testing.T) {
- testCases := []struct {
- rcvNxt seqnum.Value
- sackBlocks []header.SACKBlock
- trimmed []header.SACKBlock
- }{
- // Simple cases where we trim whole entries.
- {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}},
- {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}},
- {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}},
- {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
- // Cases where we need to update a block.
- {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}},
- {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}},
- {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}},
- {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
- }
- for _, tc := range testCases {
- var sack tcp.SACKInfo
- copy(sack.Blocks[:], tc.sackBlocks)
- sack.NumBlocks = len(tc.sackBlocks)
- tcp.TrimSACKBlockList(&sack, tc.rcvNxt)
- if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) {
- t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want)
- }
- }
-}
-
-func TestSACKRecovery(t *testing.T) {
- const maxPayload = 10
- // See: tcp.makeOptions for why tsOptionSize is set to 12 here.
- const tsOptionSize = 12
- // Enabling SACK means the payload size is reduced to account
- // for the extra space required for the TCP options.
- //
- // We increase the MTU by 40 bytes to account for SACK and Timestamp
- // options.
- const maxTCPOptionSize = 40
-
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
- defer c.Cleanup()
-
- c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
- // We use log.Printf instead of t.Logf here because this probe
- // can fire even when the test function has finished. This is
- // because closing the endpoint in cleanup() does not mean the
- // actual worker loop terminates immediately as it still has to
- // do a full TCP shutdown. But this test can finish running
- // before the shutdown is done. Using t.Logf in such a case
- // causes the test to panic due to logging after test finished.
- log.Printf("state: %+v\n", s)
- })
- setStackSACKPermitted(t, c, true)
- setStackTCPRecovery(t, c, 0)
- createConnectedWithSACKAndTS(c)
-
- const iterations = 3
- data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Do slow start for a few iterations.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(seq, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
- }
-
- // Send 3 duplicate acks. This should force an immediate retransmit of
- // the pending packet and put the sender into fast recovery.
- rtxOffset := bytesRead - maxPayload*expected
- start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
- end := start.Add(10)
- for i := 0; i < 3; i++ {
- c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}})
- end = end.Add(10)
- }
-
- // Receive the retransmitted packet.
- c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
-
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-
- // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause
- // window inflation and sending of packets is completely handled by the
- // SACK Recovery algorithm. We should see no packets being released, as
- // the cwnd at this point after entering recovery should be half of the
- // outstanding number of packets in flight.
- for i := 0; i < 7; i++ {
- c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}})
- end = end.Add(10)
- }
-
- recover := bytesRead
-
- // Ensure no new packets arrive.
- c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
- 50*time.Millisecond)
-
- // Acknowledge half of the pending data. This along with the 10 sacked
- // segments above should reduce the outstanding below the current
- // congestion window allowing the sender to transmit data.
- rtxOffset = bytesRead - expected*maxPayload/2
-
- // Now send a partial ACK w/ a SACK block that indicates that the next 3
- // segments are lost and we have received 6 segments after the lost
- // segments. This should cause the sender to immediately transmit all 3
- // segments in response to this ACK unlike in FastRecovery where only 1
- // segment is retransmitted per ACK.
- start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
- end = start.Add(60)
- c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}})
-
- // At this point, we acked expected/2 packets and we SACKED 6 packets and
- // 3 segments were considered lost due to the SACK block we sent.
- //
- // So total packets outstanding can be calculated as follows after 7
- // iterations of slow start -> 10/20/40/80/160/320/640. So expected
- // should be 640 at start, then we went to recover at which point the
- // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the
- // network).
- // Outstanding at this point after acking half the window
- // (320 packets) will be:
- // outstanding = 640-320-6(due to SACK block)-3 = 311
- //
- // The last 3 is due to the fact that the first 3 packets after
- // rtxOffset will be considered lost due to the SACK blocks sent.
- // Receive the retransmit due to partial ack.
-
- c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
- // Receive the 2 extra packets that should have been retransmitted as
- // those should be considered lost and immediately retransmitted based
- // on the SACK information in the previous ACK sent above.
- for i := 0; i < 2; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize)
- }
-
- // Now we should get 9 more new unsent packets as the cwnd is 323 and
- // outstanding is 311.
- for i := 0; i < 9; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
-
- metricPollFn = func() error {
- // In SACK recovery only the first segment is fast retransmitted when
- // entering recovery.
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
- return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
- return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want)
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-
- c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond)
-
- // Acknowledge all pending data to recover point.
- c.SendAck(seq, recover)
-
- // At this point, the cwnd should reset to expected/2 and there are 9
- // packets outstanding.
- //
- // Now in the first iteration since there are 9 packets outstanding.
- // We would expect to get expected/2 - 9 packets. But subsequent
- // iterations will send us expected/2 + 1 (per iteration).
- expected = expected/2 - 9
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(seq, bytesRead)
-
- // In cogestion avoidance, the packets trains increase by 1 in
- // each iteration.
- if i == 0 {
- // After the first iteration we expect to get the full
- // congestion window worth of packets in every
- // iteration.
- expected += 9
- }
- expected++
- }
-}
-
-// TestRecoveryEntry tests the following two properties of entering recovery:
-// - Fast SACK recovery is entered when SND.UNA is considered lost by the SACK
-// scoreboard but dupack count is still below threshold.
-// - Only enter recovery when at least one more byte of data beyond the highest
-// byte that was outstanding when fast retransmit was last entered is acked.
-func TestRecoveryEntry(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- numPackets := 5
- data := sendAndReceiveWithSACK(t, c, numPackets, false /* enableRACK */)
-
- // Ack #1 packet.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAck(seq, maxPayload)
-
- // Now SACK #3, #4 and #5 packets. This will simulate a situation where
- // SND.UNA should be considered lost and the sender should enter fast recovery
- // (even though dupack count is still below threshold).
- p3Start := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
- p3End := p3Start.Add(maxPayload)
- p4Start := p3End
- p4End := p4Start.Add(maxPayload)
- p5Start := p4End
- p5End := p5Start.Add(maxPayload)
- c.SendAckWithSACK(seq, maxPayload, []header.SACKBlock{{p3Start, p3End}, {p4Start, p4End}, {p5Start, p5End}})
-
- // Expect #2 to be retransmitted.
- c.ReceiveAndCheckPacketWithOptions(data, maxPayload, maxPayload, tsOptionSize)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // SACK recovery must have happened.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- // #2 was retransmitted.
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- // No RTOs should have fired yet.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-
- // Send 4 more packets.
- var r bytes.Reader
- data = append(data, data...)
- r.Reset(data[5*maxPayload : 9*maxPayload])
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- var sackBlocks []header.SACKBlock
- bytesRead := numPackets * maxPayload
- for i := 0; i < 4; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- if i > 0 {
- pStart := c.IRS.Add(1 + seqnum.Size(bytesRead))
- sackBlocks = append(sackBlocks, header.SACKBlock{pStart, pStart.Add(maxPayload)})
- c.SendAckWithSACK(seq, 5*maxPayload, sackBlocks)
- }
- bytesRead += maxPayload
- }
-
- // #6 should be retransmitted after RTO. The sender should NOT enter fast
- // recovery because the highest byte that was outstanding when fast recovery
- // was last entered is #5 packet's end. And the sender requires at least one
- // more byte beyond that (#6 packet start) to be acked to enter recovery.
- c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize)
- c.SendAck(seq, 9*maxPayload)
-
- metricPollFn = func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // Only 1 SACK recovery must have happened.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- // #2 and #6 were retransmitted.
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 2},
- // RTO should have fired once.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go
new file mode 100644
index 000000000..a14cff27e
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go
@@ -0,0 +1,221 @@
+package tcp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type segmentElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type segmentList struct {
+ head *segment
+ tail *segment
+}
+
+// Reset resets list l to the empty state.
+func (l *segmentList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *segmentList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *segmentList) Front() *segment {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *segmentList) Back() *segment {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *segmentList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (segmentElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *segmentList) PushFront(e *segment) {
+ linker := segmentElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ segmentElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *segmentList) PushBack(e *segment) {
+ linker := segmentElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ segmentElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *segmentList) PushBackList(m *segmentList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *segmentList) InsertAfter(b, e *segment) {
+ bLinker := segmentElementMapper{}.linkerFor(b)
+ eLinker := segmentElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ segmentElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *segmentList) InsertBefore(a, e *segment) {
+ aLinker := segmentElementMapper{}.linkerFor(a)
+ eLinker := segmentElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ segmentElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *segmentList) Remove(e *segment) {
+ linker := segmentElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ segmentElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ segmentElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type segmentEntry struct {
+ next *segment
+ prev *segment
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *segmentEntry) Next() *segment {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *segmentEntry) Prev() *segment {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *segmentEntry) SetNext(elem *segment) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *segmentEntry) SetPrev(elem *segment) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
new file mode 100644
index 000000000..13061d2b1
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -0,0 +1,901 @@
+// automatically generated by stateify.
+
+package tcp
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (c *cubicState) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.cubicState"
+}
+
+func (c *cubicState) StateFields() []string {
+ return []string{
+ "TCPCubicState",
+ "numCongestionEvents",
+ "s",
+ }
+}
+
+func (c *cubicState) beforeSave() {}
+
+// +checklocksignore
+func (c *cubicState) StateSave(stateSinkObject state.Sink) {
+ c.beforeSave()
+ stateSinkObject.Save(0, &c.TCPCubicState)
+ stateSinkObject.Save(1, &c.numCongestionEvents)
+ stateSinkObject.Save(2, &c.s)
+}
+
+func (c *cubicState) afterLoad() {}
+
+// +checklocksignore
+func (c *cubicState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &c.TCPCubicState)
+ stateSourceObject.Load(1, &c.numCongestionEvents)
+ stateSourceObject.Load(2, &c.s)
+}
+
+func (s *SACKInfo) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.SACKInfo"
+}
+
+func (s *SACKInfo) StateFields() []string {
+ return []string{
+ "Blocks",
+ "NumBlocks",
+ }
+}
+
+func (s *SACKInfo) beforeSave() {}
+
+// +checklocksignore
+func (s *SACKInfo) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ stateSinkObject.Save(0, &s.Blocks)
+ stateSinkObject.Save(1, &s.NumBlocks)
+}
+
+func (s *SACKInfo) afterLoad() {}
+
+// +checklocksignore
+func (s *SACKInfo) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.Blocks)
+ stateSourceObject.Load(1, &s.NumBlocks)
+}
+
+func (s *sndQueueInfo) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.sndQueueInfo"
+}
+
+func (s *sndQueueInfo) StateFields() []string {
+ return []string{
+ "TCPSndBufState",
+ }
+}
+
+func (s *sndQueueInfo) beforeSave() {}
+
+// +checklocksignore
+func (s *sndQueueInfo) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ stateSinkObject.Save(0, &s.TCPSndBufState)
+}
+
+func (s *sndQueueInfo) afterLoad() {}
+
+// +checklocksignore
+func (s *sndQueueInfo) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.TCPSndBufState)
+}
+
+func (r *rcvQueueInfo) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.rcvQueueInfo"
+}
+
+func (r *rcvQueueInfo) StateFields() []string {
+ return []string{
+ "TCPRcvBufState",
+ "rcvQueue",
+ }
+}
+
+func (r *rcvQueueInfo) beforeSave() {}
+
+// +checklocksignore
+func (r *rcvQueueInfo) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.TCPRcvBufState)
+ stateSinkObject.Save(1, &r.rcvQueue)
+}
+
+func (r *rcvQueueInfo) afterLoad() {}
+
+// +checklocksignore
+func (r *rcvQueueInfo) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.TCPRcvBufState)
+ stateSourceObject.LoadWait(1, &r.rcvQueue)
+}
+
+func (a *accepted) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.accepted"
+}
+
+func (a *accepted) StateFields() []string {
+ return []string{
+ "endpoints",
+ "cap",
+ }
+}
+
+func (a *accepted) beforeSave() {}
+
+// +checklocksignore
+func (a *accepted) StateSave(stateSinkObject state.Sink) {
+ a.beforeSave()
+ var endpointsValue []*endpoint
+ endpointsValue = a.saveEndpoints()
+ stateSinkObject.SaveValue(0, endpointsValue)
+ stateSinkObject.Save(1, &a.cap)
+}
+
+func (a *accepted) afterLoad() {}
+
+// +checklocksignore
+func (a *accepted) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(1, &a.cap)
+ stateSourceObject.LoadValue(0, new([]*endpoint), func(y interface{}) { a.loadEndpoints(y.([]*endpoint)) })
+}
+
+func (e *endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.endpoint"
+}
+
+func (e *endpoint) StateFields() []string {
+ return []string{
+ "TCPEndpointStateInner",
+ "TransportEndpointInfo",
+ "DefaultSocketOptionsHandler",
+ "waiterQueue",
+ "uniqueID",
+ "hardError",
+ "lastError",
+ "rcvQueueInfo",
+ "rcvMemUsed",
+ "ownedByUser",
+ "state",
+ "boundNICID",
+ "ttl",
+ "isConnectNotified",
+ "portFlags",
+ "boundBindToDevice",
+ "boundPortFlags",
+ "boundDest",
+ "effectiveNetProtos",
+ "workerRunning",
+ "workerCleanup",
+ "recentTSTime",
+ "shutdownFlags",
+ "tcpRecovery",
+ "sack",
+ "delay",
+ "scoreboard",
+ "segmentQueue",
+ "synRcvdCount",
+ "userMSS",
+ "maxSynRetries",
+ "windowClamp",
+ "sndQueueInfo",
+ "cc",
+ "keepalive",
+ "userTimeout",
+ "deferAccept",
+ "accepted",
+ "rcv",
+ "snd",
+ "connectingAddress",
+ "amss",
+ "sendTOS",
+ "gso",
+ "tcpLingerTimeout",
+ "closed",
+ "txHash",
+ "owner",
+ "ops",
+ "lastOutOfWindowAckTime",
+ }
+}
+
+// +checklocksignore
+func (e *endpoint) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ var stateValue EndpointState
+ stateValue = e.saveState()
+ stateSinkObject.SaveValue(10, stateValue)
+ stateSinkObject.Save(0, &e.TCPEndpointStateInner)
+ stateSinkObject.Save(1, &e.TransportEndpointInfo)
+ stateSinkObject.Save(2, &e.DefaultSocketOptionsHandler)
+ stateSinkObject.Save(3, &e.waiterQueue)
+ stateSinkObject.Save(4, &e.uniqueID)
+ stateSinkObject.Save(5, &e.hardError)
+ stateSinkObject.Save(6, &e.lastError)
+ stateSinkObject.Save(7, &e.rcvQueueInfo)
+ stateSinkObject.Save(8, &e.rcvMemUsed)
+ stateSinkObject.Save(9, &e.ownedByUser)
+ stateSinkObject.Save(11, &e.boundNICID)
+ stateSinkObject.Save(12, &e.ttl)
+ stateSinkObject.Save(13, &e.isConnectNotified)
+ stateSinkObject.Save(14, &e.portFlags)
+ stateSinkObject.Save(15, &e.boundBindToDevice)
+ stateSinkObject.Save(16, &e.boundPortFlags)
+ stateSinkObject.Save(17, &e.boundDest)
+ stateSinkObject.Save(18, &e.effectiveNetProtos)
+ stateSinkObject.Save(19, &e.workerRunning)
+ stateSinkObject.Save(20, &e.workerCleanup)
+ stateSinkObject.Save(21, &e.recentTSTime)
+ stateSinkObject.Save(22, &e.shutdownFlags)
+ stateSinkObject.Save(23, &e.tcpRecovery)
+ stateSinkObject.Save(24, &e.sack)
+ stateSinkObject.Save(25, &e.delay)
+ stateSinkObject.Save(26, &e.scoreboard)
+ stateSinkObject.Save(27, &e.segmentQueue)
+ stateSinkObject.Save(28, &e.synRcvdCount)
+ stateSinkObject.Save(29, &e.userMSS)
+ stateSinkObject.Save(30, &e.maxSynRetries)
+ stateSinkObject.Save(31, &e.windowClamp)
+ stateSinkObject.Save(32, &e.sndQueueInfo)
+ stateSinkObject.Save(33, &e.cc)
+ stateSinkObject.Save(34, &e.keepalive)
+ stateSinkObject.Save(35, &e.userTimeout)
+ stateSinkObject.Save(36, &e.deferAccept)
+ stateSinkObject.Save(37, &e.accepted)
+ stateSinkObject.Save(38, &e.rcv)
+ stateSinkObject.Save(39, &e.snd)
+ stateSinkObject.Save(40, &e.connectingAddress)
+ stateSinkObject.Save(41, &e.amss)
+ stateSinkObject.Save(42, &e.sendTOS)
+ stateSinkObject.Save(43, &e.gso)
+ stateSinkObject.Save(44, &e.tcpLingerTimeout)
+ stateSinkObject.Save(45, &e.closed)
+ stateSinkObject.Save(46, &e.txHash)
+ stateSinkObject.Save(47, &e.owner)
+ stateSinkObject.Save(48, &e.ops)
+ stateSinkObject.Save(49, &e.lastOutOfWindowAckTime)
+}
+
+// +checklocksignore
+func (e *endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.TCPEndpointStateInner)
+ stateSourceObject.Load(1, &e.TransportEndpointInfo)
+ stateSourceObject.Load(2, &e.DefaultSocketOptionsHandler)
+ stateSourceObject.LoadWait(3, &e.waiterQueue)
+ stateSourceObject.Load(4, &e.uniqueID)
+ stateSourceObject.Load(5, &e.hardError)
+ stateSourceObject.Load(6, &e.lastError)
+ stateSourceObject.Load(7, &e.rcvQueueInfo)
+ stateSourceObject.Load(8, &e.rcvMemUsed)
+ stateSourceObject.Load(9, &e.ownedByUser)
+ stateSourceObject.Load(11, &e.boundNICID)
+ stateSourceObject.Load(12, &e.ttl)
+ stateSourceObject.Load(13, &e.isConnectNotified)
+ stateSourceObject.Load(14, &e.portFlags)
+ stateSourceObject.Load(15, &e.boundBindToDevice)
+ stateSourceObject.Load(16, &e.boundPortFlags)
+ stateSourceObject.Load(17, &e.boundDest)
+ stateSourceObject.Load(18, &e.effectiveNetProtos)
+ stateSourceObject.Load(19, &e.workerRunning)
+ stateSourceObject.Load(20, &e.workerCleanup)
+ stateSourceObject.Load(21, &e.recentTSTime)
+ stateSourceObject.Load(22, &e.shutdownFlags)
+ stateSourceObject.Load(23, &e.tcpRecovery)
+ stateSourceObject.Load(24, &e.sack)
+ stateSourceObject.Load(25, &e.delay)
+ stateSourceObject.Load(26, &e.scoreboard)
+ stateSourceObject.LoadWait(27, &e.segmentQueue)
+ stateSourceObject.Load(28, &e.synRcvdCount)
+ stateSourceObject.Load(29, &e.userMSS)
+ stateSourceObject.Load(30, &e.maxSynRetries)
+ stateSourceObject.Load(31, &e.windowClamp)
+ stateSourceObject.Load(32, &e.sndQueueInfo)
+ stateSourceObject.Load(33, &e.cc)
+ stateSourceObject.Load(34, &e.keepalive)
+ stateSourceObject.Load(35, &e.userTimeout)
+ stateSourceObject.Load(36, &e.deferAccept)
+ stateSourceObject.Load(37, &e.accepted)
+ stateSourceObject.LoadWait(38, &e.rcv)
+ stateSourceObject.LoadWait(39, &e.snd)
+ stateSourceObject.Load(40, &e.connectingAddress)
+ stateSourceObject.Load(41, &e.amss)
+ stateSourceObject.Load(42, &e.sendTOS)
+ stateSourceObject.Load(43, &e.gso)
+ stateSourceObject.Load(44, &e.tcpLingerTimeout)
+ stateSourceObject.Load(45, &e.closed)
+ stateSourceObject.Load(46, &e.txHash)
+ stateSourceObject.Load(47, &e.owner)
+ stateSourceObject.Load(48, &e.ops)
+ stateSourceObject.Load(49, &e.lastOutOfWindowAckTime)
+ stateSourceObject.LoadValue(10, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) })
+ stateSourceObject.AfterLoad(e.afterLoad)
+}
+
+func (k *keepalive) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.keepalive"
+}
+
+func (k *keepalive) StateFields() []string {
+ return []string{
+ "idle",
+ "interval",
+ "count",
+ "unacked",
+ }
+}
+
+func (k *keepalive) beforeSave() {}
+
+// +checklocksignore
+func (k *keepalive) StateSave(stateSinkObject state.Sink) {
+ k.beforeSave()
+ stateSinkObject.Save(0, &k.idle)
+ stateSinkObject.Save(1, &k.interval)
+ stateSinkObject.Save(2, &k.count)
+ stateSinkObject.Save(3, &k.unacked)
+}
+
+func (k *keepalive) afterLoad() {}
+
+// +checklocksignore
+func (k *keepalive) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &k.idle)
+ stateSourceObject.Load(1, &k.interval)
+ stateSourceObject.Load(2, &k.count)
+ stateSourceObject.Load(3, &k.unacked)
+}
+
+func (rc *rackControl) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.rackControl"
+}
+
+func (rc *rackControl) StateFields() []string {
+ return []string{
+ "TCPRACKState",
+ "exitedRecovery",
+ "minRTT",
+ "tlpRxtOut",
+ "tlpHighRxt",
+ "snd",
+ }
+}
+
+func (rc *rackControl) beforeSave() {}
+
+// +checklocksignore
+func (rc *rackControl) StateSave(stateSinkObject state.Sink) {
+ rc.beforeSave()
+ stateSinkObject.Save(0, &rc.TCPRACKState)
+ stateSinkObject.Save(1, &rc.exitedRecovery)
+ stateSinkObject.Save(2, &rc.minRTT)
+ stateSinkObject.Save(3, &rc.tlpRxtOut)
+ stateSinkObject.Save(4, &rc.tlpHighRxt)
+ stateSinkObject.Save(5, &rc.snd)
+}
+
+func (rc *rackControl) afterLoad() {}
+
+// +checklocksignore
+func (rc *rackControl) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &rc.TCPRACKState)
+ stateSourceObject.Load(1, &rc.exitedRecovery)
+ stateSourceObject.Load(2, &rc.minRTT)
+ stateSourceObject.Load(3, &rc.tlpRxtOut)
+ stateSourceObject.Load(4, &rc.tlpHighRxt)
+ stateSourceObject.Load(5, &rc.snd)
+}
+
+func (r *receiver) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.receiver"
+}
+
+func (r *receiver) StateFields() []string {
+ return []string{
+ "TCPReceiverState",
+ "ep",
+ "rcvWnd",
+ "rcvWUP",
+ "prevBufUsed",
+ "closed",
+ "pendingRcvdSegments",
+ "lastRcvdAckTime",
+ }
+}
+
+func (r *receiver) beforeSave() {}
+
+// +checklocksignore
+func (r *receiver) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.TCPReceiverState)
+ stateSinkObject.Save(1, &r.ep)
+ stateSinkObject.Save(2, &r.rcvWnd)
+ stateSinkObject.Save(3, &r.rcvWUP)
+ stateSinkObject.Save(4, &r.prevBufUsed)
+ stateSinkObject.Save(5, &r.closed)
+ stateSinkObject.Save(6, &r.pendingRcvdSegments)
+ stateSinkObject.Save(7, &r.lastRcvdAckTime)
+}
+
+func (r *receiver) afterLoad() {}
+
+// +checklocksignore
+func (r *receiver) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.TCPReceiverState)
+ stateSourceObject.Load(1, &r.ep)
+ stateSourceObject.Load(2, &r.rcvWnd)
+ stateSourceObject.Load(3, &r.rcvWUP)
+ stateSourceObject.Load(4, &r.prevBufUsed)
+ stateSourceObject.Load(5, &r.closed)
+ stateSourceObject.Load(6, &r.pendingRcvdSegments)
+ stateSourceObject.Load(7, &r.lastRcvdAckTime)
+}
+
+func (r *renoState) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.renoState"
+}
+
+func (r *renoState) StateFields() []string {
+ return []string{
+ "s",
+ }
+}
+
+func (r *renoState) beforeSave() {}
+
+// +checklocksignore
+func (r *renoState) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.s)
+}
+
+func (r *renoState) afterLoad() {}
+
+// +checklocksignore
+func (r *renoState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.s)
+}
+
+func (rr *renoRecovery) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.renoRecovery"
+}
+
+func (rr *renoRecovery) StateFields() []string {
+ return []string{
+ "s",
+ }
+}
+
+func (rr *renoRecovery) beforeSave() {}
+
+// +checklocksignore
+func (rr *renoRecovery) StateSave(stateSinkObject state.Sink) {
+ rr.beforeSave()
+ stateSinkObject.Save(0, &rr.s)
+}
+
+func (rr *renoRecovery) afterLoad() {}
+
+// +checklocksignore
+func (rr *renoRecovery) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &rr.s)
+}
+
+func (sr *sackRecovery) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.sackRecovery"
+}
+
+func (sr *sackRecovery) StateFields() []string {
+ return []string{
+ "s",
+ }
+}
+
+func (sr *sackRecovery) beforeSave() {}
+
+// +checklocksignore
+func (sr *sackRecovery) StateSave(stateSinkObject state.Sink) {
+ sr.beforeSave()
+ stateSinkObject.Save(0, &sr.s)
+}
+
+func (sr *sackRecovery) afterLoad() {}
+
+// +checklocksignore
+func (sr *sackRecovery) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &sr.s)
+}
+
+func (s *SACKScoreboard) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.SACKScoreboard"
+}
+
+func (s *SACKScoreboard) StateFields() []string {
+ return []string{
+ "smss",
+ "maxSACKED",
+ }
+}
+
+func (s *SACKScoreboard) beforeSave() {}
+
+// +checklocksignore
+func (s *SACKScoreboard) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ stateSinkObject.Save(0, &s.smss)
+ stateSinkObject.Save(1, &s.maxSACKED)
+}
+
+func (s *SACKScoreboard) afterLoad() {}
+
+// +checklocksignore
+func (s *SACKScoreboard) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.smss)
+ stateSourceObject.Load(1, &s.maxSACKED)
+}
+
+func (s *segment) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.segment"
+}
+
+func (s *segment) StateFields() []string {
+ return []string{
+ "segmentEntry",
+ "refCnt",
+ "ep",
+ "qFlags",
+ "srcAddr",
+ "dstAddr",
+ "netProto",
+ "nicID",
+ "data",
+ "hdr",
+ "sequenceNumber",
+ "ackNumber",
+ "flags",
+ "window",
+ "csum",
+ "csumValid",
+ "parsedOptions",
+ "options",
+ "hasNewSACKInfo",
+ "rcvdTime",
+ "xmitTime",
+ "xmitCount",
+ "acked",
+ "dataMemSize",
+ "lost",
+ }
+}
+
+func (s *segment) beforeSave() {}
+
+// +checklocksignore
+func (s *segment) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ var dataValue buffer.VectorisedView
+ dataValue = s.saveData()
+ stateSinkObject.SaveValue(8, dataValue)
+ var optionsValue []byte
+ optionsValue = s.saveOptions()
+ stateSinkObject.SaveValue(17, optionsValue)
+ stateSinkObject.Save(0, &s.segmentEntry)
+ stateSinkObject.Save(1, &s.refCnt)
+ stateSinkObject.Save(2, &s.ep)
+ stateSinkObject.Save(3, &s.qFlags)
+ stateSinkObject.Save(4, &s.srcAddr)
+ stateSinkObject.Save(5, &s.dstAddr)
+ stateSinkObject.Save(6, &s.netProto)
+ stateSinkObject.Save(7, &s.nicID)
+ stateSinkObject.Save(9, &s.hdr)
+ stateSinkObject.Save(10, &s.sequenceNumber)
+ stateSinkObject.Save(11, &s.ackNumber)
+ stateSinkObject.Save(12, &s.flags)
+ stateSinkObject.Save(13, &s.window)
+ stateSinkObject.Save(14, &s.csum)
+ stateSinkObject.Save(15, &s.csumValid)
+ stateSinkObject.Save(16, &s.parsedOptions)
+ stateSinkObject.Save(18, &s.hasNewSACKInfo)
+ stateSinkObject.Save(19, &s.rcvdTime)
+ stateSinkObject.Save(20, &s.xmitTime)
+ stateSinkObject.Save(21, &s.xmitCount)
+ stateSinkObject.Save(22, &s.acked)
+ stateSinkObject.Save(23, &s.dataMemSize)
+ stateSinkObject.Save(24, &s.lost)
+}
+
+func (s *segment) afterLoad() {}
+
+// +checklocksignore
+func (s *segment) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.segmentEntry)
+ stateSourceObject.Load(1, &s.refCnt)
+ stateSourceObject.Load(2, &s.ep)
+ stateSourceObject.Load(3, &s.qFlags)
+ stateSourceObject.Load(4, &s.srcAddr)
+ stateSourceObject.Load(5, &s.dstAddr)
+ stateSourceObject.Load(6, &s.netProto)
+ stateSourceObject.Load(7, &s.nicID)
+ stateSourceObject.Load(9, &s.hdr)
+ stateSourceObject.Load(10, &s.sequenceNumber)
+ stateSourceObject.Load(11, &s.ackNumber)
+ stateSourceObject.Load(12, &s.flags)
+ stateSourceObject.Load(13, &s.window)
+ stateSourceObject.Load(14, &s.csum)
+ stateSourceObject.Load(15, &s.csumValid)
+ stateSourceObject.Load(16, &s.parsedOptions)
+ stateSourceObject.Load(18, &s.hasNewSACKInfo)
+ stateSourceObject.Load(19, &s.rcvdTime)
+ stateSourceObject.Load(20, &s.xmitTime)
+ stateSourceObject.Load(21, &s.xmitCount)
+ stateSourceObject.Load(22, &s.acked)
+ stateSourceObject.Load(23, &s.dataMemSize)
+ stateSourceObject.Load(24, &s.lost)
+ stateSourceObject.LoadValue(8, new(buffer.VectorisedView), func(y interface{}) { s.loadData(y.(buffer.VectorisedView)) })
+ stateSourceObject.LoadValue(17, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) })
+}
+
+func (q *segmentQueue) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.segmentQueue"
+}
+
+func (q *segmentQueue) StateFields() []string {
+ return []string{
+ "list",
+ "ep",
+ "frozen",
+ }
+}
+
+func (q *segmentQueue) beforeSave() {}
+
+// +checklocksignore
+func (q *segmentQueue) StateSave(stateSinkObject state.Sink) {
+ q.beforeSave()
+ stateSinkObject.Save(0, &q.list)
+ stateSinkObject.Save(1, &q.ep)
+ stateSinkObject.Save(2, &q.frozen)
+}
+
+func (q *segmentQueue) afterLoad() {}
+
+// +checklocksignore
+func (q *segmentQueue) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.LoadWait(0, &q.list)
+ stateSourceObject.Load(1, &q.ep)
+ stateSourceObject.Load(2, &q.frozen)
+}
+
+func (s *sender) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.sender"
+}
+
+func (s *sender) StateFields() []string {
+ return []string{
+ "TCPSenderState",
+ "ep",
+ "lr",
+ "firstRetransmittedSegXmitTime",
+ "writeNext",
+ "writeList",
+ "rtt",
+ "minRTO",
+ "maxRTO",
+ "maxRetries",
+ "gso",
+ "state",
+ "cc",
+ "rc",
+ }
+}
+
+func (s *sender) beforeSave() {}
+
+// +checklocksignore
+func (s *sender) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ stateSinkObject.Save(0, &s.TCPSenderState)
+ stateSinkObject.Save(1, &s.ep)
+ stateSinkObject.Save(2, &s.lr)
+ stateSinkObject.Save(3, &s.firstRetransmittedSegXmitTime)
+ stateSinkObject.Save(4, &s.writeNext)
+ stateSinkObject.Save(5, &s.writeList)
+ stateSinkObject.Save(6, &s.rtt)
+ stateSinkObject.Save(7, &s.minRTO)
+ stateSinkObject.Save(8, &s.maxRTO)
+ stateSinkObject.Save(9, &s.maxRetries)
+ stateSinkObject.Save(10, &s.gso)
+ stateSinkObject.Save(11, &s.state)
+ stateSinkObject.Save(12, &s.cc)
+ stateSinkObject.Save(13, &s.rc)
+}
+
+func (s *sender) afterLoad() {}
+
+// +checklocksignore
+func (s *sender) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.TCPSenderState)
+ stateSourceObject.Load(1, &s.ep)
+ stateSourceObject.Load(2, &s.lr)
+ stateSourceObject.Load(3, &s.firstRetransmittedSegXmitTime)
+ stateSourceObject.Load(4, &s.writeNext)
+ stateSourceObject.Load(5, &s.writeList)
+ stateSourceObject.Load(6, &s.rtt)
+ stateSourceObject.Load(7, &s.minRTO)
+ stateSourceObject.Load(8, &s.maxRTO)
+ stateSourceObject.Load(9, &s.maxRetries)
+ stateSourceObject.Load(10, &s.gso)
+ stateSourceObject.Load(11, &s.state)
+ stateSourceObject.Load(12, &s.cc)
+ stateSourceObject.Load(13, &s.rc)
+}
+
+func (r *rtt) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.rtt"
+}
+
+func (r *rtt) StateFields() []string {
+ return []string{
+ "TCPRTTState",
+ }
+}
+
+func (r *rtt) beforeSave() {}
+
+// +checklocksignore
+func (r *rtt) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.TCPRTTState)
+}
+
+func (r *rtt) afterLoad() {}
+
+// +checklocksignore
+func (r *rtt) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.TCPRTTState)
+}
+
+func (l *endpointList) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.endpointList"
+}
+
+func (l *endpointList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *endpointList) beforeSave() {}
+
+// +checklocksignore
+func (l *endpointList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *endpointList) afterLoad() {}
+
+// +checklocksignore
+func (l *endpointList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *endpointEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.endpointEntry"
+}
+
+func (e *endpointEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *endpointEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *endpointEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *endpointEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *endpointEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func (l *segmentList) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.segmentList"
+}
+
+func (l *segmentList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *segmentList) beforeSave() {}
+
+// +checklocksignore
+func (l *segmentList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *segmentList) afterLoad() {}
+
+// +checklocksignore
+func (l *segmentList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *segmentEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.segmentEntry"
+}
+
+func (e *segmentEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *segmentEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *segmentEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *segmentEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *segmentEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*cubicState)(nil))
+ state.Register((*SACKInfo)(nil))
+ state.Register((*sndQueueInfo)(nil))
+ state.Register((*rcvQueueInfo)(nil))
+ state.Register((*accepted)(nil))
+ state.Register((*endpoint)(nil))
+ state.Register((*keepalive)(nil))
+ state.Register((*rackControl)(nil))
+ state.Register((*receiver)(nil))
+ state.Register((*renoState)(nil))
+ state.Register((*renoRecovery)(nil))
+ state.Register((*sackRecovery)(nil))
+ state.Register((*SACKScoreboard)(nil))
+ state.Register((*segment)(nil))
+ state.Register((*segmentQueue)(nil))
+ state.Register((*sender)(nil))
+ state.Register((*rtt)(nil))
+ state.Register((*endpointList)(nil))
+ state.Register((*endpointEntry)(nil))
+ state.Register((*segmentList)(nil))
+ state.Register((*segmentEntry)(nil))
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
deleted file mode 100644
index 6f1ee3816..000000000
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ /dev/null
@@ -1,8602 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "bytes"
- "fmt"
- "io/ioutil"
- "math"
- "strings"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- tcpiptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/test/testutil"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// endpointTester provides helper functions to test a tcpip.Endpoint.
-type endpointTester struct {
- ep tcpip.Endpoint
-}
-
-// CheckReadError issues a read to the endpoint and checking for an error.
-func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) {
- t.Helper()
- res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{})
- if got != want {
- t.Fatalf("ep.Read = %s, want %s", got, want)
- }
- if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" {
- t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff)
- }
-}
-
-// CheckRead issues a read to the endpoint and checking for a success, returning
-// the data read.
-func (e *endpointTester) CheckRead(t *testing.T) []byte {
- t.Helper()
- var buf bytes.Buffer
- res, err := e.ep.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("ep.Read = _, %s; want _, nil", err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- return buf.Bytes()
-}
-
-// CheckReadFull reads from the endpoint for exactly count bytes.
-func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte {
- t.Helper()
- var buf bytes.Buffer
- w := tcpip.LimitedWriter{
- W: &buf,
- N: int64(count),
- }
- for w.N != 0 {
- _, err := e.ep.Read(&w, tcpip.ReadOptions{})
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for receive to be notified.
- select {
- case <-notifyRead:
- case <-time.After(timeout):
- t.Fatalf("Timed out waiting for data to arrive")
- }
- continue
- } else if err != nil {
- t.Fatalf("ep.Read = _, %s; want _, nil", err)
- }
- }
- return buf.Bytes()
-}
-
-const (
- // defaultMTU is the MTU, in bytes, used throughout the tests, except
- // where another value is explicitly used. It is chosen to match the MTU
- // of loopback interfaces on linux systems.
- defaultMTU = 65535
-
- // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an
- // IPv4 endpoint when the MTU is set to defaultMTU in the test.
- defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
-)
-
-func TestGiveUpConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- var wq waiter.Queue
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Register for notification, then start connection attempt.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.EventHUp)
- defer wq.EventUnregister(&waitEntry)
-
- {
- err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- // Close the connection, wait for completion.
- ep.Close()
-
- // Wait for ep to become writable.
- <-notifyCh
-
- // Call Connect again to retreive the handshake failure status
- // and stats updates.
- {
- err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" {
- t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got)
- }
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
-}
-
-// Test for ICMP error handling without completing handshake.
-func TestConnectICMPError(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- var wq waiter.Queue
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.EventHUp)
- defer wq.EventUnregister(&waitEntry)
-
- {
- err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- syn := c.GetPacket()
- checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn)))
-
- wep := ep.(interface {
- StopWork()
- ResumeWork()
- LastErrorLocked() tcpip.Error
- })
-
- // Stop the protocol loop, ensure that the ICMP error is processed and
- // the last ICMP error is read before the loop is resumed. This sanity
- // tests the handshake completion logic on ICMP errors.
- wep.StopWork()
-
- c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, defaultMTU)
-
- for {
- if err := wep.LastErrorLocked(); err != nil {
- if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
- t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d)
- }
- break
- }
- time.Sleep(time.Millisecond)
- }
-
- wep.ResumeWork()
-
- <-notifyCh
-
- // The stack would have unregistered the endpoint because of the ICMP error.
- // Expect a RST for any subsequent packets sent to the endpoint.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1,
- AckNum: c.IRS + 1,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst)))
-}
-
-func TestConnectIncrementActiveConnection(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- want := stats.TCP.ActiveConnectionOpenings.Value() + 1
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want)
- }
-}
-
-func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- want := stats.TCP.FailedConnectionAttempts.Value()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want)
- }
-}
-
-func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- c.EP = ep
- want := stats.TCP.FailedConnectionAttempts.Value() + 1
-
- {
- err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
- t.Errorf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want)
- }
-}
-
-func TestCloseWithoutConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- c.EP.Close()
-
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-}
-
-func TestTCPSegmentsSentIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- // SYN and ACK
- want := stats.TCP.SegmentsSent.Value() + 2
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- if got := stats.TCP.SegmentsSent.Value(); got != want {
- t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
- t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want)
- }
-}
-
-func TestTCPResetsSentIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- stats := c.Stack().Stats()
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- want := stats.TCP.SegmentsSent.Value() + 1
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- // If the AckNum is not the increment of the last sequence number, a RST
- // segment is sent back in response.
- AckNum: c.IRS + 2,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- c.GetPacket()
-
- metricPollFn := func() error {
- if got := stats.TCP.ResetsSent.Value(); got != want {
- return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want)
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestTCPResetsSentNoICMP confirms that we don't get an ICMP
-// DstUnreachable packet when we try send a packet which is not part
-// of an active session.
-func TestTCPResetsSentNoICMP(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- stats := c.Stack().Stats()
-
- // Send a SYN request for a closed port. This should elicit an RST
- // but NOT an ICMPv4 DstUnreachable packet.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- })
-
- // Receive whatever comes back.
- b := c.GetPacket()
- ipHdr := header.IPv4(b)
- if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want {
- t.Errorf("unexpected protocol, got = %d, want = %d", got, want)
- }
-
- // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded.
- sent := stats.ICMP.V4.PacketsSent
- if got, want := sent.DstUnreachable.Value(), uint64(0); got != want {
- t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want)
- }
-}
-
-// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
-// a RST if an ACK is received on the listening socket for which there is no
-// active handshake in progress and we are not using SYN cookies.
-func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Lower stackwide TIME_WAIT timeout so that the reservations
- // are released instantly on Close.
- tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err)
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- c.GetPacket()
-
- // Since an active close was done we need to wait for a little more than
- // tcpLingerTimeout for the port reservations to be released and the
- // socket to move to a CLOSED state.
- time.Sleep(20 * time.Millisecond)
-
- // Now resend the same ACK, this ACK should generate a RST as there
- // should be no endpoint in SYN-RCVD state and we are not using
- // syn-cookies yet. The reason we send the same ACK is we need a valid
- // cookie(IRS) generated by the netstack without which the ACK will be
- // rejected.
- c.SendPacket(nil, ackHeaders)
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst)))
-}
-
-func TestTCPResetsReceivedIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- want := stats.TCP.ResetsReceived.Value() + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- rcvWnd := seqnum.Size(30000)
- c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- SeqNum: iss.Add(1),
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- Flags: header.TCPFlagRst,
- })
-
- if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
- }
-}
-
-func TestTCPResetsDoNotGenerateResets(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- want := stats.TCP.ResetsReceived.Value() + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- rcvWnd := seqnum.Size(30000)
- c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- SeqNum: iss.Add(1),
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- Flags: header.TCPFlagRst,
- })
-
- if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
- }
- c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
-}
-
-func TestActiveHandshake(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-}
-
-func TestNonBlockingClose(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- ep := c.EP
- c.EP = nil
-
- // Close the endpoint and measure how long it takes.
- t0 := time.Now()
- ep.Close()
- if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %s", diff)
- }
-}
-
-func TestConnectResetAfterClose(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPLinger to 3 seconds so that sockets are marked closed
- // after 3 second in FIN_WAIT2 state.
- tcpLingerTimeout := 3 * time.Second
- opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- ep := c.EP
- c.EP = nil
-
- // Close the endpoint, make sure we get a FIN segment, then acknowledge
- // to complete closure of sender, but don't send our own FIN.
- ep.Close()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Wait for the ep to give up waiting for a FIN.
- time.Sleep(tcpLingerTimeout + 1*time.Second)
-
- // Now send an ACK and it should trigger a RST as the endpoint should
- // not exist anymore.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- for {
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin {
- // This is a retransmit of the FIN, ignore it.
- continue
- }
-
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- // RST is always generated with sndNxt which if the FIN
- // has been sent will be 1 higher than the sequence number
- // of the FIN itself.
- checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
- break
- }
-}
-
-// TestCurrentConnectedIncrement tests increment of the current
-// established and connected counters.
-func TestCurrentConnectedIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
- // after 1 second in TIME_WAIT state.
- tcpTimeWaitTimeout := 1 * time.Second
- opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- ep := c.EP
- c.EP = nil
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got)
- }
- gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value()
- if gotConnected != 1 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected)
- }
-
- ep.Close()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected)
- }
-
- // Ack and send FIN as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Check that the stack acks the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Wait for a little more than the TIME-WAIT duration for the socket to
- // transition to CLOSED state.
- time.Sleep(1200 * time.Millisecond)
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-}
-
-// TestClosingWithEnqueuedSegments tests handling of still enqueued segments
-// when the endpoint transitions to StateClose. The in-flight segments would be
-// re-enqueued to a any listening endpoint.
-func TestClosingWithEnqueuedSegments(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- ep := c.EP
- c.EP = nil
-
- if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
- t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
- }
-
- // Send a FIN for ESTABLISHED --> CLOSED-WAIT
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagFin | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Get the ACK for the FIN we sent.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Give the stack a few ms to transition the endpoint out of ESTABLISHED
- // state.
- time.Sleep(10 * time.Millisecond)
-
- if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
- t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
- }
-
- // Close the application endpoint for CLOSE_WAIT --> LAST_ACK
- ep.Close()
-
- // Get the FIN
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- // Pause the endpoint`s protocolMainLoop.
- ep.(interface{ StopWork() }).StopWork()
-
- // Enqueue last ACK followed by an ACK matching the endpoint
- //
- // Send Last ACK for LAST_ACK --> CLOSED
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(1),
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Send a packet with ACK set, this would generate RST when
- // not using SYN cookies as in this test.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss.Add(2),
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Unpause endpoint`s protocolMainLoop.
- ep.(interface{ ResumeWork() }).ResumeWork()
-
- // Wait for the protocolMainLoop to resume and update state.
- time.Sleep(10 * time.Millisecond)
-
- // Expect the endpoint to be closed.
- if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got)
- }
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
-
- // Check if the endpoint was moved to CLOSED and netstack a reset in
- // response to the ACK packet that we sent after last-ACK.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
-}
-
-func TestSimpleReceive(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
-
- data := []byte{1, 2, 3}
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Receive data.
- v := ept.CheckRead(t)
- if !bytes.Equal(data, v) {
- t.Fatalf("got data = %v, want = %v", v, data)
- }
-
- // Check that ACK is received.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when
-// creating a new active TCP socket. It should be present in the sent TCP
-// SYN segment.
-func TestUserSuppliedMSSOnConnect(t *testing.T) {
- const mtu = 5000
-
- ips := []struct {
- name string
- createEP func(*context.Context)
- connectAddr tcpip.Address
- checker func(*testing.T, *context.Context, uint16, int)
- maxMSS uint16
- }{
- {
- name: "IPv4",
- createEP: func(c *context.Context) {
- c.Create(-1)
- },
- connectAddr: context.TestAddr,
- checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
- },
- maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
- },
- {
- name: "IPv6",
- createEP: func(c *context.Context) {
- c.CreateV6Endpoint(true)
- },
- connectAddr: context.TestV6Addr,
- checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
- },
- maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
- },
- }
-
- for _, ip := range ips {
- t.Run(ip.name, func(t *testing.T) {
- tests := []struct {
- name string
- setMSS uint16
- expMSS uint16
- }{
- {
- name: "EqualToMaxMSS",
- setMSS: ip.maxMSS,
- expMSS: ip.maxMSS,
- },
- {
- name: "LessThanMaxMSS",
- setMSS: ip.maxMSS - 1,
- expMSS: ip.maxMSS - 1,
- },
- {
- name: "GreaterThanMaxMSS",
- setMSS: ip.maxMSS + 1,
- expMSS: ip.maxMSS,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- ip.createEP(c)
-
- // Set the MSS socket option.
- if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
- t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
- }
-
- // Get expected window size.
- rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize()
- ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
-
- connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort}
- {
- err := c.EP.Connect(connectAddr)
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d)
- }
- }
-
- // Receive SYN packet with our user supplied MSS.
- ip.checker(t, c, test.expMSS, ws)
- })
- }
- })
- }
-}
-
-// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used
-// when completing the handshake for a new TCP connection from a TCP
-// listening socket. It should be present in the sent TCP SYN-ACK segment.
-func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
- const mtu = 5000
-
- ips := []struct {
- name string
- createEP func(*context.Context)
- sendPkt func(*context.Context, *context.Headers)
- checker func(*testing.T, *context.Context, uint16, uint16)
- maxMSS uint16
- }{
- {
- name: "IPv4",
- createEP: func(c *context.Context) {
- c.Create(-1)
- },
- sendPkt: func(c *context.Context, h *context.Headers) {
- c.SendPacket(nil, h)
- },
- checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(srcPort),
- checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
- },
- maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
- },
- {
- name: "IPv6",
- createEP: func(c *context.Context) {
- c.CreateV6Endpoint(false)
- },
- sendPkt: func(c *context.Context, h *context.Headers) {
- c.SendV6Packet(nil, h)
- },
- checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(srcPort),
- checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
- },
- maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
- },
- }
-
- for _, ip := range ips {
- t.Run(ip.name, func(t *testing.T) {
- tests := []struct {
- name string
- setMSS uint16
- expMSS uint16
- }{
- {
- name: "EqualToMaxMSS",
- setMSS: ip.maxMSS,
- expMSS: ip.maxMSS,
- },
- {
- name: "LessThanMaxMSS",
- setMSS: ip.maxMSS - 1,
- expMSS: ip.maxMSS - 1,
- },
- {
- name: "GreaterThanMaxMSS",
- setMSS: ip.maxMSS + 1,
- expMSS: ip.maxMSS,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- ip.createEP(c)
-
- if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
- t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
- }
-
- bindAddr := tcpip.FullAddress{Port: context.StackPort}
- if err := c.EP.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%+v): %s:", bindAddr, err)
- }
-
- backlog := 5
- // Keep the number of client requests twice to the backlog
- // such that half of the connections do not use syncookies
- // and the other half does.
- clientConnects := backlog * 2
-
- if err := c.EP.Listen(backlog); err != nil {
- t.Fatalf("Listen(%d): %s:", backlog, err)
- }
-
- for i := 0; i < clientConnects; i++ {
- // Send a SYN requests.
- iss := seqnum.Value(i)
- srcPort := context.TestPort + uint16(i)
- ip.sendPkt(c, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- })
-
- // Receive the SYN-ACK reply.
- ip.checker(t, c, srcPort, test.expMSS)
- }
- })
- }
- })
- }
-}
-func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: 100,
- AckNum: 200,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- checker.TCPSeqNum(200)))
-}
-
-func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: 100,
- AckNum: 200,
- })
-
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- checker.TCPSeqNum(200)))
-}
-
-// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
-// peers can send data and expect a response within a reasonable ammount of time
-// without calling Accept on the listening endpoint first.
-//
-// This test uses IPv4.
-func TestTCPAckBeforeAcceptV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
-
- // Send data before accepting the connection.
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(iss+1)),
- checker.TCPAckNum(uint32(irs+5))))
-}
-
-// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
-// peers can send data and expect a response within a reasonable ammount of time
-// without calling Accept on the listening endpoint first.
-//
-// This test uses IPv6.
-func TestTCPAckBeforeAcceptV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */)
-
- // Send data before accepting the connection.
- c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(iss+1)),
- checker.TCPAckNum(uint32(irs+5))))
-}
-
-func TestSendRstOnListenerRxAckV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1 /* epRcvBuf */)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10 /* backlog */); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagFin | header.TCPFlagAck,
- SeqNum: 100,
- AckNum: 200,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- checker.TCPSeqNum(200)))
-}
-
-func TestSendRstOnListenerRxAckV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true /* v6Only */)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10 /* backlog */); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagFin | header.TCPFlagAck,
- SeqNum: 100,
- AckNum: 200,
- })
-
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- checker.TCPSeqNum(200)))
-}
-
-// TestListenShutdown tests for the listening endpoint replying with RST
-// on read shutdown.
-func TestListenShutdown(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1 /* epRcvBuf */)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(1 /* backlog */); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatal("Shutdown failed:", err)
- }
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: 100,
- AckNum: 200,
- })
-
- // Expect the listening endpoint to reset the connection.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- ))
-}
-
-var _ waiter.EntryCallback = (callback)(nil)
-
-type callback func(*waiter.Entry, waiter.EventMask)
-
-func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) {
- cb(entry, mask)
-}
-
-func TestListenerReadinessOnEvent(t *testing.T) {
- s := stack.New(stack.Options{
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- })
- {
- ep := loopback.New()
- if testing.Verbose() {
- ep = sniffer.New(ep)
- }
- const id = 1
- if err := s.CreateNIC(id, ep); err != nil {
- t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(),
- }
- if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: header.IPv4EmptySubnet, NIC: id},
- })
- }
-
- var wq waiter.Queue
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil {
- t.Fatalf("Bind(%s): %s", context.StackAddr, err)
- }
- const backlog = 1
- if err := ep.Listen(backlog); err != nil {
- t.Fatalf("Listen(%d): %s", backlog, err)
- }
-
- address, err := ep.GetLocalAddress()
- if err != nil {
- t.Fatalf("GetLocalAddress(): %s", err)
- }
-
- conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err)
- }
- defer conn.Close()
-
- events := make(chan waiter.EventMask)
- // Scope `entry` to allow a binding of the same name below.
- {
- entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) {
- events <- ep.Readiness(mask)
- })}
- wq.EventRegister(&entry, waiter.EventIn)
- defer wq.EventUnregister(&entry)
- }
-
- entry, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&entry, waiter.EventOut)
- defer wq.EventUnregister(&entry)
-
- switch err := conn.Connect(address).(type) {
- case *tcpip.ErrConnectStarted:
- default:
- t.Fatalf("Connect(%#v): %v", address, err)
- }
-
- // Read at least one event.
- got := <-events
- for {
- select {
- case event := <-events:
- got |= event
- continue
- case <-ch:
- if want := waiter.ReadableEvents; got != want {
- t.Errorf("observed events = %b, want %b", got, want)
- }
- }
- break
- }
-}
-
-// TestListenCloseWhileConnect tests for the listening endpoint to
-// drain the accept-queue when closed. This should reset all of the
-// pending connections that are waiting to be accepted.
-func TestListenCloseWhileConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1 /* epRcvBuf */)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(1 /* backlog */); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&waitEntry)
-
- executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- // Wait for the new endpoint created because of handshake to be delivered
- // to the listening endpoint's accept queue.
- <-notifyCh
-
- // Close the listening endpoint.
- c.EP.Close()
-
- // Expect the listening endpoint to reset the connection.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- ))
-}
-
-func TestTOSV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- c.EP = ep
-
- const tos = 0xC0
- if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
- t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err)
- }
-
- v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption)
- if err != nil {
- t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err)
- }
-
- if v != tos {
- t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos)
- }
-
- testV4Connect(t, c, checker.TOS(tos, 0))
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)), // Acknum is initial sequence number + 1
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- checker.TOS(tos, 0),
- )
-
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Errorf("got data = %x, want = %x", p, data)
- }
-}
-
-func TestTrafficClassV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- const tos = 0xC0
- if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil {
- t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err)
- }
-
- v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption)
- if err != nil {
- t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err)
- }
-
- if v != tos {
- t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos)
- }
-
- // Test the connection request.
- testV6Connect(t, c, checker.TOS(tos, 0))
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received.
- b := c.GetV6Packet()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv6(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- checker.TOS(tos, 0),
- )
-
- if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Errorf("got data = %x, want = %x", p, data)
- }
-}
-
-func TestConnectBindToDevice(t *testing.T) {
- for _, test := range []struct {
- name string
- device tcpip.NICID
- want tcp.EndpointState
- }{
- {"RightDevice", 1, tcp.StateEstablished},
- {"WrongDevice", 2, tcp.StateSynSent},
- {"AnyDevice", 0, tcp.StateEstablished},
- } {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
- if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err)
- }
- // Start connection attempt.
- waitEntry, _ := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&waitEntry)
-
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
- }
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- rcvWnd := seqnum.Size(30000)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: nil,
- })
-
- c.GetPacket()
- if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
- t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
- }
- })
- }
-}
-
-func TestShutdownConnectingSocket(t *testing.T) {
- for _, test := range []struct {
- name string
- shutdownMode tcpip.ShutdownFlags
- }{
- {"ShutdownRead", tcpip.ShutdownRead},
- {"ShutdownWrite", tcpip.ShutdownWrite},
- {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite},
- } {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create an endpoint, don't handshake because we want to interfere with
- // the handshake process.
- c.Create(-1)
-
- waitEntry, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
- defer c.WQ.EventUnregister(&waitEntry)
-
- // Start connection attempt.
- addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" {
- t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
- }
-
- // Check the SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
-
- if err := c.EP.Shutdown(test.shutdownMode); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- // The endpoint internal state is updated immediately.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
-
- select {
- case <-ch:
- default:
- t.Fatal("endpoint was not notified")
- }
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrConnectionReset{})
-
- // If the endpoint is not properly shutdown, it'll re-attempt to connect
- // by sending another ACK packet.
- c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond))
- })
- }
-}
-
-func TestSynSent(t *testing.T) {
- for _, test := range []struct {
- name string
- reset bool
- }{
- {"RstOnSynSent", true},
- {"CloseOnSynSent", false},
- } {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create an endpoint, don't handshake because we want to interfere with the
- // handshake process.
- c.Create(-1)
-
- // Start connection attempt.
- waitEntry, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
- defer c.WQ.EventUnregister(&waitEntry)
-
- addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
- err := c.EP.Connect(addr)
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- if test.reset {
- // Send a packet with a proper ACK and a RST flag to cause the socket
- // to error and close out.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- rcvWnd := seqnum.Size(30000)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagRst | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: nil,
- })
- } else {
- c.EP.Close()
- }
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatal("timed out waiting for packet to arrive")
- }
-
- ept := endpointTester{c.EP}
- if test.reset {
- ept.CheckReadError(t, &tcpip.ErrConnectionRefused{})
- } else {
- ept.CheckReadError(t, &tcpip.ErrAborted{})
- }
-
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-
- // Due to the RST the endpoint should be in an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
- })
- }
-}
-
-func TestOutOfOrderReceive(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Send second half of data first, with seqnum 3 ahead of expected.
- data := []byte{1, 2, 3, 4, 5, 6}
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(data[3:], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(3),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Check that we get an ACK specifying which seqnum is expected.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Wait 200ms and check that no data has been received.
- time.Sleep(200 * time.Millisecond)
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Send the first 3 bytes now.
- c.SendPacket(data[:3], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive data.
- read := ept.CheckReadFull(t, 6, ch, 5*time.Second)
-
- // Check that we received the data in proper order.
- if !bytes.Equal(data, read) {
- t.Fatalf("got data = %v, want = %v", read, data)
- }
-
- // Check that the whole data is acknowledged.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestOutOfOrderFlood(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- rcvBufSz := math.MaxUint16
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Send 100 packets before the actual one that is expected.
- data := []byte{1, 2, 3, 4, 5, 6}
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for i := 0; i < 100; i++ {
- c.SendPacket(data[3:], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(6),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- // Send packet with seqnum as initial + 3. It must be discarded because the
- // out-of-order buffer was filled by the previous packets.
- c.SendPacket(data[3:], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(3),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Now send the expected packet with initial sequence number.
- c.SendPacket(data[:3], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Check that only packet with initial sequence number is acknowledged.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+3),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestRstOnCloseWithUnreadData(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- data := []byte{1, 2, 3}
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that ACK is received, this happens regardless of the read.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Now that we know we have unread data, let's just close the connection
- // and verify that netstack sends an RST rather than a FIN.
- c.EP.Close()
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- // We shouldn't consume a sequence number on RST.
- checker.TCPSeqNum(uint32(c.IRS)+1),
- ))
- // The RST puts the endpoint into an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- // This final ACK should be ignored because an ACK on a reset doesn't mean
- // anything.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(len(data))),
- AckNum: c.IRS.Add(seqnum.Size(2)),
- RcvWnd: 30000,
- })
-}
-
-func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- data := []byte{1, 2, 3}
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that ACK is received, this happens regardless of the read.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Cause a FIN to be generated.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- // Make sure we get the FIN but DON't ACK IT.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- ))
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- // Cause a RST to be generated by closing the read end now since we have
- // unread data.
- if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- // Make sure we get the RST
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- // RST is always generated with sndNxt which if the FIN
- // has been sent will be 1 higher than the sequence
- // number of the FIN itself.
- checker.TCPSeqNum(uint32(c.IRS)+2),
- ))
- // The RST puts the endpoint into an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- // The ACK to the FIN should now be rejected since the connection has been
- // closed by a RST.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(len(data))),
- AckNum: c.IRS.Add(seqnum.Size(2)),
- RcvWnd: 30000,
- })
-}
-
-func TestShutdownRead(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- ept.CheckReadError(t, &tcpip.ErrClosedForReceive{})
- var want uint64 = 1
- if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
- t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
- }
-}
-
-func TestFullWindowReceive(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- const rcvBufSz = 10
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
- // the provided buffer value by tcp.SegOverheadFactor to calculate the actual
- // receive buffer size.
- data := make([]byte, tcp.SegOverheadFactor*rcvBufSz)
- for i := range data {
- data[i] = byte(i % 255)
- }
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that data is acknowledged, and window goes to zero.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPWindow(0),
- ),
- )
-
- // Receive data and check it.
- v := ept.CheckRead(t)
- if !bytes.Equal(data, v) {
- t.Fatalf("got data = %v, want = %v", v, data)
- }
-
- var want uint64 = 1
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
- t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want)
- }
-
- // Check that we get an ACK for the newly non-zero window.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPWindow(10),
- ),
- )
-}
-
-func TestSmallReceiveBufferReadiness(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- })
-
- ep := loopback.New()
- if testing.Verbose() {
- ep = sniffer.New(ep)
- }
-
- const nicID = 1
- nicOpts := stack.NICOptions{Name: "nic1"}
- if err := s.CreateNICWithOptions(nicID, ep, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x7f\x00\x00\x01"),
- PrefixLen: 8,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
- if err != nil {
- t.Fatalf("tcpip.NewSubnet failed: %s", err)
- }
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: subnet,
- NIC: nicID,
- },
- })
- }
-
- listenerEntry, listenerCh := waiter.NewChannelEntry(nil)
- var listenerWQ waiter.Queue
- listener, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer listener.Close()
- listenerWQ.EventRegister(&listenerEntry, waiter.ReadableEvents)
- defer listenerWQ.EventUnregister(&listenerEntry)
-
- if err := listener.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if err := listener.Listen(1); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- localAddress, err := listener.GetLocalAddress()
- if err != nil {
- t.Fatalf("GetLocalAddress failed: %s", err)
- }
-
- for i := 8; i > 0; i /= 2 {
- size := int64(i << 10)
- t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) {
- var clientWQ waiter.Queue
- client, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer client.Close()
- switch err := client.Connect(localAddress).(type) {
- case nil:
- t.Fatal("Connect returned nil error")
- case *tcpip.ErrConnectStarted:
- default:
- t.Fatalf("Connect failed: %s", err)
- }
-
- <-listenerCh
- server, serverWQ, err := listener.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
- defer server.Close()
-
- client.SocketOptions().SetReceiveBufferSize(size, true)
- // Send buffer size doesn't seem to affect this test.
- // server.SocketOptions().SetSendBufferSize(size, true)
-
- clientEntry, clientCh := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientEntry, waiter.ReadableEvents)
- defer clientWQ.EventUnregister(&clientEntry)
-
- serverEntry, serverCh := waiter.NewChannelEntry(nil)
- serverWQ.EventRegister(&serverEntry, waiter.WritableEvents)
- defer serverWQ.EventUnregister(&serverEntry)
-
- var total int64
- for {
- var b [64 << 10]byte
- var r bytes.Reader
- r.Reset(b[:])
- switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
- case nil:
- t.Logf("wrote %d bytes", n)
- total += n
- continue
- case *tcpip.ErrWouldBlock:
- select {
- case <-serverCh:
- continue
- case <-time.After(100 * time.Millisecond):
- // Well and truly full.
- t.Logf("send and receive queues are full")
- }
- default:
- t.Fatalf("Write failed: %s", err)
- }
- break
- }
- t.Logf("wrote %d bytes in total", total)
-
- var wg sync.WaitGroup
- defer wg.Wait()
-
- wg.Add(2)
- go func() {
- defer wg.Done()
-
- var b [64 << 10]byte
- var r bytes.Reader
- r.Reset(b[:])
- if err := func() error {
- var total int64
- defer t.Logf("wrote %d bytes in total", total)
- for r.Len() != 0 {
- switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
- case nil:
- t.Logf("wrote %d bytes", n)
- total += n
- case *tcpip.ErrWouldBlock:
- for {
- t.Logf("waiting on server")
- select {
- case <-serverCh:
- case <-time.After(time.Second):
- if readiness := server.Readiness(waiter.WritableEvents); readiness != 0 {
- t.Logf("server.Readiness(%b) = %b but channel not signaled", waiter.WritableEvents, readiness)
- }
- continue
- }
- break
- }
- default:
- return fmt.Errorf("server.Write failed: %s", err)
- }
- }
- if err := server.Shutdown(tcpip.ShutdownWrite); err != nil {
- return fmt.Errorf("server.Shutdown failed: %s", err)
- }
- t.Logf("server end shutdown done")
- return nil
- }(); err != nil {
- t.Error(err)
- }
- }()
-
- go func() {
- defer wg.Done()
-
- if err := func() error {
- total := 0
- defer t.Logf("read %d bytes in total", total)
- for {
- switch res, err := client.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
- case nil:
- t.Logf("read %d bytes", res.Count)
- total += res.Count
- t.Logf("read total %d bytes till now", total)
- case *tcpip.ErrClosedForReceive:
- return nil
- case *tcpip.ErrWouldBlock:
- for {
- t.Logf("waiting on client")
- select {
- case <-clientCh:
- case <-time.After(time.Second):
- if readiness := client.Readiness(waiter.ReadableEvents); readiness != 0 {
- return fmt.Errorf("client.Readiness(%b) = %b but channel not signaled", waiter.ReadableEvents, readiness)
- }
- continue
- }
- break
- }
- default:
- return fmt.Errorf("client.Write failed: %s", err)
- }
- }
- }(); err != nil {
- t.Error(err)
- }
- }()
- })
- }
-}
-
-// Test the stack receive window advertisement on receiving segments smaller than
-// segment overhead. It tests for the right edge of the window to not grow when
-// the endpoint is not being read from.
-func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- opt := tcpip.TCPReceiveBufferSizeRangeOption{
- Min: 1,
- Default: tcp.DefaultReceiveBufferSize,
- Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)),
- }
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
-
- c.AcceptWithOptionsNoDelay(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Bump up the receive buffer size such that, when the receive window grows,
- // the scaled window exceeds maxUint16.
- c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max)*2, true /* notify */)
-
- // Keep the payload size < segment overhead and such that it is a multiple
- // of the window scaled value. This enables the test to perform equality
- // checks on the incoming receive window.
- payloadSize := 1 << c.RcvdWindowScale
- if payloadSize >= tcp.SegSize {
- t.Fatalf("payload size of %d is not less than the segment overhead of %d", payloadSize, tcp.SegSize)
- }
- payload := generateRandomPayload(t, payloadSize)
- payloadLen := seqnum.Size(len(payload))
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
-
- // Send payload to the endpoint and return the advertised receive window
- // from the endpoint.
- getIncomingRcvWnd := func() uint32 {
- c.SendPacket(payload, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- Flags: header.TCPFlagAck,
- RcvWnd: 30000,
- })
- iss = iss.Add(payloadLen)
-
- pkt := c.GetPacket()
- return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale
- }
-
- // Read the advertised receive window with the ACK for payload.
- rcvWnd := getIncomingRcvWnd()
-
- // Check if the subsequent ACK to our send has not grown the right edge of
- // the window.
- if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want {
- t.Fatalf("got incomingRcvwnd %d want %d", got, want)
- }
-
- // Read the data so that the subsequent ACK from the endpoint
- // grows the right edge of the window.
- var buf bytes.Buffer
- if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil {
- t.Fatalf("c.EP.Read: %s", err)
- }
-
- // Check if we have received max uint16 as our advertised
- // scaled window now after a read above.
- maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale)
- if got, want := getIncomingRcvWnd(), maxRcv; got != want {
- t.Fatalf("got incomingRcvwnd %d want %d", got, want)
- }
-
- // Check if the subsequent ACK to our send has not grown the right edge of
- // the window.
- if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want {
- t.Fatalf("got incomingRcvwnd %d want %d", got, want)
- }
-}
-
-func TestNoWindowShrinking(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Start off with a certain receive buffer then cut it in half and verify that
- // the right edge of the window does not shrink.
- // NOTE: Netstack doubles the value specified here.
- rcvBufSize := 65536
- // Enable window scaling with a scale of zero from our end.
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, rcvBufSize, []byte{
- header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
- })
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Send a 1 byte payload so that we can record the current receive window.
- // Send a payload of half the size of rcvBufSize.
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- payload := []byte{1}
- c.SendPacket(payload, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Read the 1 byte payload we just sent.
- if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) {
- t.Fatalf("got data: %v, want: %v", got, want)
- }
-
- // Verify that the ACK does not shrink the window.
- pkt := c.GetPacket()
- iss = iss.Add(1)
- checker.IPv4(t, pkt,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- // Stash the initial window.
- initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
- initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd))
- // Now shrink the receive buffer to half its original size.
- c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize), true /* notify */)
-
- data := generateRandomPayload(t, rcvBufSize)
- // Send a payload of half the size of rcvBufSize.
- c.SendPacket(data[:rcvBufSize/2], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- iss = iss.Add(seqnum.Size(rcvBufSize / 2))
-
- // Verify that the ACK does not shrink the window.
- pkt = c.GetPacket()
- checker.IPv4(t, pkt,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
- newLastAcceptableSeq := iss.Add(seqnum.Size(newWnd))
- if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) {
- t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq)
- }
-
- // Send another payload of half the size of rcvBufSize. This should fill up the
- // socket receive buffer and we should see a zero window.
- c.SendPacket(data[rcvBufSize/2:], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- iss = iss.Add(seqnum.Size(rcvBufSize / 2))
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPWindow(0),
- ),
- )
-
- // Receive data and check it.
- read := ept.CheckReadFull(t, len(data), ch, 5*time.Second)
- if !bytes.Equal(data, read) {
- t.Fatalf("got data = %v, want = %v", read, data)
- }
-
- // Check that we get an ACK for the newly non-zero window, which is the new
- // receive buffer size we set after the connection was established.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale),
- ),
- )
-}
-
-func TestSimpleSend(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Fatalf("got data = %v, want = %v", p, data)
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1 + seqnum.Size(len(data))),
- RcvWnd: 30000,
- })
-}
-
-func TestZeroWindowSend(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 0 /* rcvWnd */, -1 /* epRcvBuf */)
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check if we got a zero-window probe.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- // Open up the window. Data should be received now.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Check that data is received.
- b = c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Fatalf("got data = %v, want = %v", p, data)
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1 + seqnum.Size(len(data))),
- RcvWnd: 30000,
- })
-}
-
-func TestScaledWindowConnect(t *testing.T) {
- // This test ensures that window scaling is used when the peer
- // does advertise it and connection is established with Connect().
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set the window size greater than the maximum non-scaled window.
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, 65535*3, []byte{
- header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
- })
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received, and that advertised window is 0x5fff,
- // that is, that it is scaled.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPWindow(0x5fff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-}
-
-func TestNonScaledWindowConnect(t *testing.T) {
- // This test ensures that window scaling is not used when the peer
- // doesn't advertise it and connection is established with Connect().
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set the window size greater than the maximum non-scaled window.
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, 65535*3)
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received, and that advertised window is 0xffff,
- // that is, that it's not scaled.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPWindow(0xffff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-}
-
-func TestScaledWindowAccept(t *testing.T) {
- // This test ensures that window scaling is used when the peer
- // does advertise it and connection is established with Accept().
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
-
- // Set the window size greater than the maximum non-scaled window.
- ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */)
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Do 3-way handshake.
- // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2
- c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received, and that advertised window is 0x5fff,
- // that is, that it is scaled.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPWindow(0x5fff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-}
-
-func TestNonScaledWindowAccept(t *testing.T) {
- // This test ensures that window scaling is not used when the peer
- // doesn't advertise it and connection is established with Accept().
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
-
- // Set the window size greater than the maximum non-scaled window.
- ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */)
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
- // should not carry the window scaling option.
- c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received, and that advertised window is 0xffff,
- // that is, that it's not scaled.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPWindow(0xffff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-}
-
-func TestZeroScaledWindowReceive(t *testing.T) {
- // This test ensures that the endpoint sends a non-zero window size
- // advertisement when the scaled window transitions from 0 to non-zero,
- // but the actual window (not scaled) hasn't gotten to zero.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set the buffer size such that a window scale of 5 will be used.
- const bufSz = 65535 * 10
- const ws = uint32(5)
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, bufSz, []byte{
- header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
- })
-
- // Write chunks of 50000 bytes.
- remain := 0
- sent := 0
- data := make([]byte, 50000)
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- // Keep writing till the window drops below len(data).
- for {
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(sent)),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- sent += len(data)
- pkt := c.GetPacket()
- checker.IPv4(t, pkt,
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- // Don't reduce window to zero here.
- if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) {
- remain = wnd << ws
- break
- }
- }
-
- // Make the window non-zero, but the scaled window zero.
- for remain >= 16 {
- data = data[:remain-15]
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(sent)),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- sent += len(data)
- pkt := c.GetPacket()
- checker.IPv4(t, pkt,
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- // Since the receive buffer is split between window advertisement and
- // application data buffer the window does not always reflect the space
- // available and actual space available can be a bit more than what is
- // advertised in the window.
- wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize())
- if wnd == 0 {
- break
- }
- remain = wnd << ws
- }
-
- // Read at least 2MSS of data. An ack should be sent in response to that.
- // Since buffer space is now split in half between window and application
- // data we need to read more than 1 MSS(65536) of data for a non-zero window
- // update to be sent. For 1MSS worth of window to be available we need to
- // read at least 128KB. Since our segments above were 50KB each it means
- // we need to read at 3 packets.
- w := tcpip.LimitedWriter{
- W: ioutil.Discard,
- N: defaultMTU * 2,
- }
- for w.N != 0 {
- res, err := c.EP.Read(&w, tcpip.ReadOptions{})
- t.Logf("err=%v res=%#v", err, res)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestSegmentMerging(t *testing.T) {
- tests := []struct {
- name string
- stop func(tcpip.Endpoint)
- resume func(tcpip.Endpoint)
- }{
- {
- "stop work",
- func(ep tcpip.Endpoint) {
- ep.(interface{ StopWork() }).StopWork()
- },
- func(ep tcpip.Endpoint) {
- ep.(interface{ ResumeWork() }).ResumeWork()
- },
- },
- {
- "cork",
- func(ep tcpip.Endpoint) {
- ep.SocketOptions().SetCorkOption(true)
- },
- func(ep tcpip.Endpoint) {
- ep.SocketOptions().SetCorkOption(false)
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Send tcp.InitialCwnd number of segments to fill up
- // InitialWindow but don't ACK. That should prevent
- // anymore packets from going out.
- var r bytes.Reader
- for i := 0; i < tcp.InitialCwnd; i++ {
- r.Reset([]byte{0})
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %s", i+1, err)
- }
- }
-
- // Now send the segments that should get merged as the congestion
- // window is full and we won't be able to send any more packets.
- var allData []byte
- for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
- allData = append(allData, data...)
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %s", i+1, err)
- }
- }
-
- // Check that we get tcp.InitialCwnd packets.
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for i := 0; i < tcp.InitialCwnd; i++ {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(header.TCPMinimumSize+1),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload.
- RcvWnd: 30000,
- })
-
- // Check that data is received.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(allData)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+11),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) {
- t.Fatalf("got data = %v, want = %v", got, allData)
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))),
- RcvWnd: 30000,
- })
- })
- }
-}
-
-func TestDelay(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- c.EP.SocketOptions().SetDelayOption(true)
-
- var allData []byte
- for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
- allData = append(allData, data...)
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %s", i+1, err)
- }
- }
-
- seq := c.IRS.Add(1)
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for _, want := range [][]byte{allData[:1], allData[1:]} {
- // Check that data is received.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(want)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) {
- t.Fatalf("got data = %v, want = %v", got, want)
- }
-
- seq = seq.Add(seqnum.Size(len(want)))
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seq,
- RcvWnd: 30000,
- })
- }
-}
-
-func TestUndelay(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- c.EP.SocketOptions().SetDelayOption(true)
-
- allData := [][]byte{{0}, {1, 2, 3}}
- for i, data := range allData {
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %s", i+1, err)
- }
- }
-
- seq := c.IRS.Add(1)
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- // Check that data is received.
- first := c.GetPacket()
- checker.IPv4(t, first,
- checker.PayloadLen(len(allData[0])+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) {
- t.Fatalf("got first packet's data = %v, want = %v", got, want)
- }
-
- seq = seq.Add(seqnum.Size(len(allData[0])))
-
- // Check that we don't get the second packet yet.
- c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
-
- c.EP.SocketOptions().SetDelayOption(false)
-
- // Check that data is received.
- second := c.GetPacket()
- checker.IPv4(t, second,
- checker.PayloadLen(len(allData[1])+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) {
- t.Fatalf("got second packet's data = %v, want = %v", got, want)
- }
-
- seq = seq.Add(seqnum.Size(len(allData[1])))
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seq,
- RcvWnd: 30000,
- })
-}
-
-func TestMSSNotDelayed(t *testing.T) {
- tests := []struct {
- name string
- fn func(tcpip.Endpoint)
- }{
- {"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }},
- {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }},
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- })
-
- test.fn(c.EP)
-
- allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)}
- for i, data := range allData {
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %s", i+1, err)
- }
- }
-
- seq := c.IRS.Add(1)
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for i, data := range allData {
- // Check that data is received.
- packet := c.GetPacket()
- checker.IPv4(t, packet,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) {
- t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want)
- }
-
- seq = seq.Add(seqnum.Size(len(data)))
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seq,
- RcvWnd: 30000,
- })
- })
- }
-}
-
-func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
- payloadMultiplier := 10
- dataLen := payloadMultiplier * maxPayload
- data := make([]byte, dataLen)
- for i := range data {
- data[i] = byte(i)
- }
-
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received in chunks.
- bytesReceived := 0
- numPackets := 0
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for bytesReceived != dataLen {
- b := c.GetPacket()
- numPackets++
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- payloadLen := len(tcpHdr.Payload())
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- pdata := data[bytesReceived : bytesReceived+payloadLen]
- if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
- t.Fatalf("got data = %v, want = %v", p, pdata)
- }
- bytesReceived += payloadLen
- var options []byte
- if c.TimeStampEnabled {
- // If timestamp option is enabled, echo back the timestamp and increment
- // the TSEcr value included in the packet and send that back as the TSVal.
- parsedOpts := tcpHdr.ParsedOptions()
- tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
- header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
- options = tsOpt[:]
- }
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
- RcvWnd: 30000,
- TCPOpts: options,
- })
- }
- if numPackets == 1 {
- t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet")
- }
-}
-
-func TestSendGreaterThanMTU(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestSetTTL(t *testing.T) {
- for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
- t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
- c := context.New(t, 65535)
- defer c.Cleanup()
-
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
- t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
- }
-
- {
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
-
- checker.IPv4(t, b, checker.TTL(wantTTL))
- })
- }
-}
-
-func TestActiveSendMSSLessThanMTU(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, 65535)
- defer c.Cleanup()
-
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- })
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestPassiveSendMSSLessThanMTU(t *testing.T) {
- const maxPayload = 100
- const mtu = 1200
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
-
- // Set the buffer size to a deterministic size so that we can check the
- // window scaling option.
- const rcvBufferSize = 0x20000
- ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */)
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Do 3-way handshake.
- c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Check that data gets properly segmented.
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
- const maxPayload = 536
- const mtu = 2000
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Do 3-way handshake.
- c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Check that data gets properly segmented.
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestForwarderSendMSSLessThanMTU(t *testing.T) {
- const maxPayload = 100
- const mtu = 1200
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- s := c.Stack()
- ch := make(chan tcpip.Error, 1)
- f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
- var err tcpip.Error
- c.EP, err = r.CreateEndpoint(&c.WQ)
- ch <- err
- })
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
-
- // Do 3-way handshake.
- c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
-
- // Wait for connection to be available.
- select {
- case err := <-ch:
- if err != nil {
- t.Fatalf("Error creating endpoint: %s", err)
- }
- case <-time.After(2 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-
- // Check that data gets properly segmented.
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestSynOptionsOnActiveConnect(t *testing.T) {
- const mtu = 1400
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Set the buffer size to a deterministic size so that we can check the
- // window scaling option.
- const rcvBufferSize = 0x20000
- const wndScale = 3
- c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */)
-
- // Start connection attempt.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&we)
-
- {
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize)
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
- ),
- )
-
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- // Wait for retransmit.
- time.Sleep(1 * time.Second)
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.SrcPort(tcpHdr.SourcePort()),
- checker.TCPSeqNum(tcpHdr.SequenceNumber()),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
- ),
- )
-
- // Send SYN-ACK.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive ACK packet.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+1),
- ),
- )
-
- // Wait for connection to be established.
- select {
- case <-ch:
- if err := c.EP.LastError(); err != nil {
- t.Fatalf("Connect failed: %s", err)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-}
-
-func TestCloseListener(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create listener.
- var wq waiter.Queue
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Close the listener and measure how long it takes.
- t0 := time.Now()
- ep.Close()
- if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %s", diff)
- }
-}
-
-func TestReceiveOnResetConnection(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- // Send RST segment.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagRst,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Try to read.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
-loop:
- for {
- switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
- case *tcpip.ErrWouldBlock:
- <-ch
- // Expect the state to be StateError and subsequent Reads to fail with HardError.
- _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
- t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
- }
- break loop
- case *tcpip.ErrConnectionReset:
- break loop
- default:
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{})
- }
- }
-
- if tcp.EndpointState(c.EP.State()) != tcp.StateError {
- t.Fatalf("got EP state is not StateError")
- }
-
- checkValid := func() []error {
- var errors []error
- if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
- errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got))
- }
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got))
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got))
- }
- return errors
- }
-
- start := time.Now()
- for time.Since(start) < time.Minute && len(checkValid()) > 0 {
- time.Sleep(50 * time.Millisecond)
- }
- for _, err := range checkValid() {
- t.Error(err)
- }
-}
-
-func TestSendOnResetConnection(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Send RST segment.
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagRst,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Wait for the RST to be received.
- time.Sleep(1 * time.Second)
-
- // Try to write.
- var r bytes.Reader
- r.Reset(make([]byte, 10))
- _, err := c.EP.Write(&r, tcpip.WriteOptions{})
- if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
- t.Fatalf("c.EP.Write(...) mismatch (-want +got):\n%s", d)
- }
-}
-
-// TestMaxRetransmitsTimeout tests if the connection is timed out after
-// a segment has been retransmitted MaxRetries times.
-func TestMaxRetransmitsTimeout(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- const numRetries = 2
- opt := tcpip.TCPMaxRetriesOption(numRetries)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
- // Wait for the connection to timeout after MaxRetries retransmits.
- initRTO := time.Second
- minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
- }
- c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
-
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
- defer c.WQ.EventUnregister(&waitEntry)
-
- var r bytes.Reader
- r.Reset(make([]byte, 1))
- _, err := c.EP.Write(&r, tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Expect first transmit and MaxRetries retransmits.
- for i := 0; i < numRetries+1; i++ {
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
- ),
- )
- }
- select {
- case <-notifyCh:
- case <-time.After((2 << numRetries) * initRTO):
- t.Fatalf("connection still alive after maximum retransmits.\n")
- }
-
- // Send an ACK and expect a RST as the connection would have been closed.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
-
- if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-}
-
-// TestMaxRTO tests if the retransmit interval caps to MaxRTO.
-func TestMaxRTO(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- rto := 1 * time.Second
- minRTOOpt := tcpip.TCPMinRTOOption(rto / 2)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
- }
- maxRTOOpt := tcpip.TCPMaxRTOOption(rto)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTOOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, maxRTOOpt, maxRTOOpt, err)
- }
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
-
- var r bytes.Reader
- r.Reset(make([]byte, 1))
- _, err := c.EP.Write(&r, tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- const numRetransmits = 2
- for i := 0; i < numRetransmits; i++ {
- start := time.Now()
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- if elapsed := time.Since(start); elapsed.Round(time.Second).Seconds() != rto.Seconds() {
- t.Errorf("Retransmit interval not capped to MaxRTO(%s). %s", rto, elapsed)
- }
- }
-}
-
-// TestZeroSizedWriteRetransmit tests that a zero sized write should not
-// result in a panic on an RTO as no segment should have been queued for
-// a zero sized write.
-func TestZeroSizedWriteRetransmit(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
-
- var r bytes.Reader
- _, err := c.EP.Write(&r, tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- // Now do a non-zero sized write to trigger actual sending of data.
- r.Reset(make([]byte, 1))
- _, err = c.EP.Write(&r, tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- // Do not ACK the packet and expect an original transmit and a
- // retransmit. This should not cause a panic.
- for i := 0; i < 2; i++ {
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- }
-}
-
-// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
-// unique on retransmits.
-func TestRetransmitIPv4IDUniqueness(t *testing.T) {
- for _, tc := range []struct {
- name string
- size int
- }{
- {"1Byte", 1},
- {"512Bytes", 512},
- } {
- t.Run(tc.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- minRTOOpt := tcpip.TCPMinRTOOption(time.Second)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
- }
- c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
-
- // Disabling PMTU discovery causes all packets sent from this socket to
- // have DF=0. This needs to be done because the IPv4 ID uniqueness
- // applies only to non-atomic IPv4 datagrams as defined in RFC 6864
- // Section 4, and datagrams with DF=0 are non-atomic.
- if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil {
- t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
- }
-
- var r bytes.Reader
- r.Reset(make([]byte, tc.size))
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- pkt := c.GetPacket()
- checker.IPv4(t, pkt,
- checker.FragmentFlags(0),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}}
- // Expect two retransmitted packets, and that all packets received have
- // unique IPv4 ID values.
- for i := 0; i <= 2; i++ {
- pkt := c.GetPacket()
- checker.IPv4(t, pkt,
- checker.FragmentFlags(0),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- id := header.IPv4(pkt).ID()
- if _, exists := idSet[id]; exists {
- t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id)
- }
- idSet[id] = struct{}{}
- }
- })
- }
-}
-
-func TestFinImmediately(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Shutdown immediately, check that we get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- // Ack and send FIN as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Check that the stack acks the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinRetransmit(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Shutdown immediately, check that we get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- // Don't acknowledge yet. We should get a retransmit of the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- // Ack and send FIN as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Check that the stack acks the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinWithNoPendingData(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Write something out, and have it acknowledged.
- view := make([]byte, 10)
- var r bytes.Reader
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- next := uint32(c.IRS) + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- next += uint32(len(view))
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Shutdown, check that we get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- next++
-
- // Ack and send FIN as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Check that the stack acks the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinWithPendingDataCwndFull(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Write enough segments to fill the congestion window before ACK'ing
- // any of them.
- view := make([]byte, 10)
- var r bytes.Reader
- for i := tcp.InitialCwnd; i > 0; i-- {
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- }
-
- next := uint32(c.IRS) + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for i := tcp.InitialCwnd; i > 0; i-- {
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- next += uint32(len(view))
- }
-
- // Shutdown the connection, check that the FIN segment isn't sent
- // because the congestion window doesn't allow it. Wait until a
- // retransmit is received.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- // Send the ACK that will allow the FIN to be sent as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- next++
-
- // Send a FIN that acknowledges everything. Get an ACK back.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinWithPendingData(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Write something out, and acknowledge it to get cwnd to 2.
- view := make([]byte, 10)
- var r bytes.Reader
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- next := uint32(c.IRS) + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- next += uint32(len(view))
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Write new data, but don't acknowledge it.
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- next += uint32(len(view))
-
- // Shutdown the connection, check that we do get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- next++
-
- // Send a FIN that acknowledges everything. Get an ACK back.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinWithPartialAck(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Write something out, and acknowledge it to get cwnd to 2. Also send
- // FIN from the test side.
- view := make([]byte, 10)
- var r bytes.Reader
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- next := uint32(c.IRS) + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- next += uint32(len(view))
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Check that we get an ACK for the fin.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- // Write new data, but don't acknowledge it.
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- next += uint32(len(view))
-
- // Shutdown the connection, check that we do get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- next++
-
- // Send an ACK for the data, but not for the FIN yet.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(1),
- AckNum: seqnum.Value(next - 1),
- RcvWnd: 30000,
- })
-
- // Check that we don't get a retransmit of the FIN.
- c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond)
-
- // Ack the FIN.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss.Add(1),
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-}
-
-func TestUpdateListenBacklog(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create listener.
- var wq waiter.Queue
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Update the backlog with another Listen() on the same endpoint.
- if err := ep.Listen(20); err != nil {
- t.Fatalf("Listen failed to update backlog: %s", err)
- }
-
- ep.Close()
-}
-
-func scaledSendWindow(t *testing.T, scale uint8) {
- // This test ensures that the endpoint is using the right scaling by
- // sending a buffer that is larger than the window size, and ensuring
- // that the endpoint doesn't send more than allowed.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 0, -1 /* epRcvBuf */, []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
- })
-
- // Open up the window with a scaled value.
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 1,
- })
-
- // Send some data. Check that it's capped by the window size.
- view := make([]byte, 65535)
- var r bytes.Reader
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that only data that fits in the scaled window is sent.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen((1<<scale)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- // Reset the connection to free resources.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagRst,
- SeqNum: iss,
- })
-}
-
-func TestScaledSendWindow(t *testing.T) {
- for scale := uint8(0); scale <= 14; scale++ {
- scaledSendWindow(t, scale)
- }
-}
-
-func TestReceivedValidSegmentCountIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- stats := c.Stack().Stats()
- want := stats.TCP.ValidSegmentsReceived.Value() + 1
-
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
- t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want)
- }
- // Ensure there were no errors during handshake. If these stats have
- // incremented, then the connection should not have been established.
- if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0)
- }
-}
-
-func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- stats := c.Stack().Stats()
- want := stats.TCP.InvalidSegmentsReceived.Value() + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- vv := c.BuildSegment(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
- tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4
-
- c.SendSegment(vv)
-
- if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
- }
-}
-
-func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- stats := c.Stack().Stats()
- want := stats.TCP.ChecksumErrors.Value() + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
- // Overwrite a byte in the payload which should cause checksum
- // verification to fail.
- tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4
-
- c.SendSegment(vv)
-
- if got := stats.TCP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want)
- }
-}
-
-func TestReceivedSegmentQueuing(t *testing.T) {
- // This test sends 200 segments containing a few bytes each to an
- // endpoint and checks that they're all received and acknowledged by
- // the endpoint, that is, that none of the segments are dropped by
- // internal queues.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Send 200 segments.
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- data := []byte{1, 2, 3}
- for i := 0; i < 200; i++ {
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(i * len(data))),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- }
-
- // Receive ACKs for all segments.
- last := iss.Add(seqnum.Size(200 * len(data)))
- for {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- ack := seqnum.Value(tcpHdr.AckNumber())
- if ack == last {
- break
- }
-
- if last.LessThan(ack) {
- t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last)
- }
- }
-}
-
-func TestReadAfterClosedState(t *testing.T) {
- // This test ensures that calling Read() or Peek() after the endpoint
- // has transitioned to closedState still works if there is pending
- // data. To transition to stateClosed without calling Close(), we must
- // shutdown the send path and the peer must send its own FIN.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
- // after 1 second in TIME_WAIT state.
- tcpTimeWaitTimeout := 1 * time.Second
- opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Shutdown immediately for write, check that we get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- // Send some data and acknowledge the FIN.
- data := []byte{1, 2, 3}
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Check that ACK is received.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Give the stack the chance to transition to closed state from
- // TIME_WAIT.
- time.Sleep(tcpTimeWaitTimeout * 2)
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that peek works.
- var peekBuf bytes.Buffer
- res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true})
- if err != nil {
- t.Fatalf("Peek failed: %s", err)
- }
-
- if got, want := res.Count, len(data); got != want {
- t.Fatalf("res.Count = %d, want %d", got, want)
- }
- if !bytes.Equal(data, peekBuf.Bytes()) {
- t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data)
- }
-
- // Receive data.
- v := ept.CheckRead(t)
- if !bytes.Equal(data, v) {
- t.Fatalf("got data = %v, want = %v", v, data)
- }
-
- // Now that we drained the queue, check that functions fail with the
- // right error code.
- ept.CheckReadError(t, &tcpip.ErrClosedForReceive{})
- var buf bytes.Buffer
- {
- _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true})
- if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" {
- t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d)
- }
- }
-}
-
-func TestReusePort(t *testing.T) {
- // This test ensures that ports are immediately available for reuse
- // after Close on the endpoints using them returns.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // First case, just an endpoint that was bound.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- c.EP.SocketOptions().SetReuseAddress(true)
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- c.EP.Close()
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- c.EP.SocketOptions().SetReuseAddress(true)
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- c.EP.Close()
-
- // Second case, an endpoint that was bound and is connecting..
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- c.EP.SocketOptions().SetReuseAddress(true)
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- {
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
- c.EP.Close()
-
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- c.EP.SocketOptions().SetReuseAddress(true)
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- c.EP.Close()
-
- // Third case, an endpoint that was bound and is listening.
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- c.EP.SocketOptions().SetReuseAddress(true)
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
- c.EP.Close()
-
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- c.EP.SocketOptions().SetReuseAddress(true)
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-}
-
-func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
- t.Helper()
-
- s := ep.SocketOptions().GetReceiveBufferSize()
- if int(s) != v {
- t.Fatalf("got receive buffer size = %d, want = %d", s, v)
- }
-}
-
-func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
- t.Helper()
-
- if s := ep.SocketOptions().GetSendBufferSize(); int(s) != v {
- t.Fatalf("got send buffer size = %d, want = %d", s, v)
- }
-}
-
-func TestDefaultBufferSizes(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- })
-
- // Check the default values.
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- defer func() {
- if ep != nil {
- ep.Close()
- }
- }()
-
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize)
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
-
- // Change the default send buffer size.
- {
- opt := tcpip.TCPSendBufferSizeRangeOption{
- Min: 1,
- Default: tcp.DefaultSendBufferSize * 2,
- Max: tcp.DefaultSendBufferSize * 20,
- }
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- ep.Close()
- ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
-
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
-
- // Change the default receive buffer size.
- {
- opt := tcpip.TCPReceiveBufferSizeRangeOption{
- Min: 1,
- Default: tcp.DefaultReceiveBufferSize * 3,
- Max: tcp.DefaultReceiveBufferSize * 30,
- }
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- ep.Close()
- ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
-
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3)
-}
-
-func TestBindToDeviceOption(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}})
-
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- defer ep.Close()
-
- if err := s.CreateNIC(321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %s", err)
- }
-
- // nicIDPtr is used instead of taking the address of NICID literals, which is
- // a compiler error.
- nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
- return &s
- }
-
- testActions := []struct {
- name string
- setBindToDevice *tcpip.NICID
- setBindToDeviceError tcpip.Error
- getBindToDevice int32
- }{
- {"GetDefaultValue", nil, nil, 0},
- {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0},
- {"BindToExistent", nicIDPtr(321), nil, 321},
- {"UnbindToDevice", nicIDPtr(0), nil, 0},
- }
- for _, testAction := range testActions {
- t.Run(testAction.name, func(t *testing.T) {
- if testAction.setBindToDevice != nil {
- bindToDevice := int32(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
- }
- }
- bindToDevice := ep.SocketOptions().GetBindToDevice()
- if bindToDevice != testAction.getBindToDevice {
- t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice)
- }
- })
- }
-}
-
-func makeStack() (*stack.Stack, tcpip.Error) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocol,
- ipv6.NewProtocol,
- },
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- })
-
- id := loopback.New()
- if testing.Verbose() {
- id = sniffer.New(id)
- }
-
- if err := s.CreateNIC(1, id); err != nil {
- return nil, err
- }
-
- for _, ct := range []struct {
- number tcpip.NetworkProtocolNumber
- addrWithPrefix tcpip.AddressWithPrefix
- }{
- {ipv4.ProtocolNumber, context.StackAddrWithPrefix},
- {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix},
- } {
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ct.number,
- AddressWithPrefix: ct.addrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- return nil, err
- }
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
- })
-
- return s, nil
-}
-
-func TestSelfConnect(t *testing.T) {
- // This test ensures that intentional self-connects work. In particular,
- // it checks that if an endpoint binds to say 127.0.0.1:1000 then
- // connects to 127.0.0.1:1000, then it will be connected to itself, and
- // is able to send and receive data through the same endpoint.
- s, err := makeStack()
- if err != nil {
- t.Fatal(err)
- }
-
- var wq waiter.Queue
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Register for notification, then start connection attempt.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.WritableEvents)
- defer wq.EventUnregister(&waitEntry)
-
- {
- err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- <-notifyCh
- if err := ep.LastError(); err != nil {
- t.Fatalf("Connect failed: %s", err)
- }
-
- // Write something.
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Read back what was written.
- wq.EventUnregister(&waitEntry)
- wq.EventRegister(&waitEntry, waiter.ReadableEvents)
- ept := endpointTester{ep}
- rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second)
-
- if !bytes.Equal(data, rd) {
- t.Fatalf("got data = %v, want = %v", rd, data)
- }
-}
-
-func TestConnectAvoidsBoundPorts(t *testing.T) {
- addressTypes := func(t *testing.T, network string) []string {
- switch network {
- case "ipv4":
- return []string{"v4"}
- case "ipv6":
- return []string{"v6"}
- case "dual":
- return []string{"v6", "mapped"}
- default:
- t.Fatalf("unknown network: '%s'", network)
- }
-
- panic("unreachable")
- }
-
- address := func(t *testing.T, addressType string, isAny bool) tcpip.Address {
- switch addressType {
- case "v4":
- if isAny {
- return ""
- }
- return context.StackAddr
- case "v6":
- if isAny {
- return ""
- }
- return context.StackV6Addr
- case "mapped":
- if isAny {
- return context.V4MappedWildcardAddr
- }
- return context.StackV4MappedAddr
- default:
- t.Fatalf("unknown address type: '%s'", addressType)
- }
-
- panic("unreachable")
- }
- // This test ensures that Endpoint.Connect doesn't select already-bound ports.
- networks := []string{"ipv4", "ipv6", "dual"}
- for _, exhaustedNetwork := range networks {
- t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) {
- for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) {
- t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) {
- for _, isAny := range []bool{false, true} {
- t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) {
- for _, candidateNetwork := range networks {
- t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) {
- for _, candidateAddressType := range addressTypes(t, candidateNetwork) {
- t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) {
- s, err := makeStack()
- if err != nil {
- t.Fatal(err)
- }
-
- var wq waiter.Queue
- var eps []tcpip.Endpoint
- defer func() {
- for _, ep := range eps {
- ep.Close()
- }
- }()
- makeEP := func(network string) tcpip.Endpoint {
- var networkProtocolNumber tcpip.NetworkProtocolNumber
- switch network {
- case "ipv4":
- networkProtocolNumber = ipv4.ProtocolNumber
- case "ipv6", "dual":
- networkProtocolNumber = ipv6.ProtocolNumber
- default:
- t.Fatalf("unknown network: '%s'", network)
- }
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- eps = append(eps, ep)
- switch network {
- case "ipv4":
- case "ipv6":
- ep.SocketOptions().SetV6Only(true)
- case "dual":
- ep.SocketOptions().SetV6Only(false)
- default:
- t.Fatalf("unknown network: '%s'", network)
- }
- return ep
- }
-
- var v4reserved, v6reserved bool
- switch exhaustedAddressType {
- case "v4", "mapped":
- v4reserved = true
- case "v6":
- v6reserved = true
- // Dual stack sockets bound to v6 any reserve on v4 as
- // well.
- if isAny {
- switch exhaustedNetwork {
- case "ipv6":
- case "dual":
- v4reserved = true
- default:
- t.Fatalf("unknown address type: '%s'", exhaustedNetwork)
- }
- }
- default:
- t.Fatalf("unknown address type: '%s'", exhaustedAddressType)
- }
- var collides bool
- switch candidateAddressType {
- case "v4", "mapped":
- collides = v4reserved
- case "v6":
- collides = v6reserved
- default:
- t.Fatalf("unknown address type: '%s'", candidateAddressType)
- }
-
- const (
- start = 16000
- end = 16050
- )
- if err := s.SetPortRange(start, end); err != nil {
- t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err)
- }
- for i := start; i <= end; i++ {
- if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
- t.Fatalf("Bind(%d) failed: %s", i, err)
- }
- }
- var want tcpip.Error = &tcpip.ErrConnectStarted{}
- if collides {
- want = &tcpip.ErrNoPortAvailable{}
- }
- if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
- t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want)
- }
- })
- }
- })
- }
- })
- }
- })
- }
- })
- }
-}
-
-func TestPathMTUDiscovery(t *testing.T) {
- // This test verifies the stack retransmits packets after it receives an
- // ICMP packet indicating that the path MTU has been exceeded.
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- // Create new connection with MSS of 1460.
- const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- })
-
- // Send 3200 bytes of data.
- const writeSize = 3200
- data := make([]byte, writeSize)
- for i := range data {
- data[i] = byte(i)
- }
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
- var ret []byte
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for i, size := range sizes {
- p := c.GetPacket()
- if i == which {
- ret = p
- }
- checker.IPv4(t, p,
- checker.PayloadLen(size+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(seqNum),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- seqNum += uint32(size)
- }
- return ret
- }
-
- // Receive three packets.
- sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload}
- first := receivePackets(c, sizes, 0, uint32(c.IRS)+1)
-
- // Send "packet too big" messages back to netstack.
- const newMTU = 1200
- const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- mtu := []byte{0, 0, newMTU / 256, newMTU % 256}
- c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU)
-
- // See retransmitted packets. None exceeding the new max.
- sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload}
- receivePackets(c, sizes, -1, uint32(c.IRS)+1)
-}
-
-func TestTCPEndpointProbe(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- invoked := make(chan struct{})
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that the endpoint ID is what we expect.
- //
- // We don't do an extensive validation of every field but a
- // basic sanity test.
- if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want {
- t.Fatalf("got LocalAddress: %q, want: %q", got, want)
- }
- if got, want := state.ID.LocalPort, c.Port; got != want {
- t.Fatalf("got LocalPort: %d, want: %d", got, want)
- }
- if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want {
- t.Fatalf("got RemoteAddress: %q, want: %q", got, want)
- }
- if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want {
- t.Fatalf("got RemotePort: %d, want: %d", got, want)
- }
-
- invoked <- struct{}{}
- })
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- data := []byte{1, 2, 3}
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- select {
- case <-invoked:
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("TCP Probe function was not called")
- }
-}
-
-func TestStackSetCongestionControl(t *testing.T) {
- testCases := []struct {
- cc tcpip.CongestionControlOption
- err tcpip.Error
- }{
- {"reno", nil},
- {"cubic", nil},
- {"blahblah", &tcpip.ErrNoSuchFile{}},
- }
-
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- s := c.Stack()
-
- var oldCC tcpip.CongestionControlOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err)
- }
-
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err {
- t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err)
- }
-
- var cc tcpip.CongestionControlOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
- }
-
- got, want := cc, oldCC
- // If SetTransportProtocolOption is expected to succeed
- // then the returned value for congestion control should
- // match the one specified in the
- // SetTransportProtocolOption call above, else it should
- // be what it was before the call to
- // SetTransportProtocolOption.
- if tc.err == nil {
- want = tc.cc
- }
- if got != want {
- t.Fatalf("got congestion control: %v, want: %v", got, want)
- }
- })
- }
-}
-
-func TestStackAvailableCongestionControl(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- s := c.Stack()
-
- // Query permitted congestion control algorithms.
- var aCC tcpip.TCPAvailableCongestionControlOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
- }
- if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want)
- }
-}
-
-func TestStackSetAvailableCongestionControl(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- s := c.Stack()
-
- // Setting AvailableCongestionControlOption should fail.
- aCC := tcpip.TCPAvailableCongestionControlOption("xyz")
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
- t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC)
- }
-
- // Verify that we still get the expected list of congestion control options.
- var cc tcpip.TCPAvailableCongestionControlOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
- t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err)
- }
- if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want)
- }
-}
-
-func TestEndpointSetCongestionControl(t *testing.T) {
- testCases := []struct {
- cc tcpip.CongestionControlOption
- err tcpip.Error
- }{
- {"reno", nil},
- {"cubic", nil},
- {"blahblah", &tcpip.ErrNoSuchFile{}},
- }
-
- for _, connected := range []bool{false, true} {
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- var oldCC tcpip.CongestionControlOption
- if err := c.EP.GetSockOpt(&oldCC); err != nil {
- t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err)
- }
-
- if connected {
- c.Connect(context.TestInitialSequenceNumber, 32768 /* rcvWnd */, nil)
- }
-
- if err := c.EP.SetSockOpt(&tc.cc); err != tc.err {
- t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err)
- }
-
- var cc tcpip.CongestionControlOption
- if err := c.EP.GetSockOpt(&cc); err != nil {
- t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err)
- }
-
- got, want := cc, oldCC
- // If SetSockOpt is expected to succeed then the
- // returned value for congestion control should match
- // the one specified in the SetSockOpt above, else it
- // should be what it was before the call to SetSockOpt.
- if tc.err == nil {
- want = tc.cc
- }
- if got != want {
- t.Fatalf("got congestion control = %+v, want = %+v", got, want)
- }
- })
- }
- }
-}
-
-func enableCUBIC(t *testing.T, c *context.Context) {
- t.Helper()
- opt := tcpip.CongestionControlOption("cubic")
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err)
- }
-}
-
-func TestKeepalive(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- const keepAliveIdle = 100 * time.Millisecond
- const keepAliveInterval = 3 * time.Second
- keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle)
- if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err)
- }
- keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval)
- if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err)
- }
- c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5)
- if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil {
- t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err)
- }
- c.EP.SocketOptions().SetKeepAlive(true)
-
- // 5 unacked keepalives are sent. ACK each one, and check that the
- // connection stays alive after 5.
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for i := 0; i < 10; i++ {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Acknowledge the keepalive.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS,
- RcvWnd: 30000,
- })
- }
-
- // Check that the connection is still alive.
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Send some data and wait before ACKing it. Keepalives should be disabled
- // during this period.
- view := make([]byte, 3)
- var r bytes.Reader
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- next := uint32(c.IRS) + 1
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- // Wait for the packet to be retransmitted. Verify that no keepalives
- // were sent.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
- ),
- )
- c.CheckNoPacket("Keepalive packet received while unACKed data is pending")
-
- next += uint32(len(view))
-
- // Send ACK. Keepalives should start sending again.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Now receive 5 keepalives, but don't ACK them. The connection
- // should be reset after 5.
- for i := 0; i < 5; i++ {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next-1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- // Sleep for a litte over the KeepAlive interval to make sure
- // the timer has time to fire after the last ACK and close the
- // close the socket.
- time.Sleep(keepAliveInterval + keepAliveInterval/2)
-
- // The connection should be terminated after 5 unacked keepalives.
- // Send an ACK to trigger a RST from the stack as the endpoint should
- // be dead.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)),
- )
-
- if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
- }
-
- ept.CheckReadError(t, &tcpip.ErrTimeout{})
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-}
-
-func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
- t.Helper()
- // Send a SYN request.
- irs = seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- iss = seqnum.Value(tcp.SequenceNumber())
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(srcPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs) + 1),
- }
-
- if synCookieInUse {
- // When cookies are in use window scaling is disabled.
- tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
- WS: -1,
- MSS: c.MSSWithoutOptions(),
- }))
- }
-
- checker.IPv4(t, b, checker.TCP(tcpCheckers...))
-
- // Send ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
- return irs, iss
-}
-
-func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
- t.Helper()
- // Send a SYN request.
- irs = seqnum.Value(context.TestInitialSequenceNumber)
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetV6Packet()
- tcp := header.TCP(header.IPv6(b).Payload())
- iss = seqnum.Value(tcp.SequenceNumber())
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(srcPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs) + 1),
- }
-
- if synCookieInUse {
- // When cookies are in use window scaling is disabled.
- tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
- WS: -1,
- MSS: c.MSSWithoutOptionsV6(),
- }))
- }
-
- checker.IPv6(t, b, checker.TCP(tcpCheckers...))
-
- // Send ACK.
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
- return irs, iss
-}
-
-// TestListenBacklogFull tests that netstack does not complete handshakes if the
-// listen backlog for the endpoint is full.
-func TestListenBacklogFull(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- // Start listening.
- listenBacklog := 10
- if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- lastPortOffset := uint16(0)
- for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ {
- executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
- }
-
- time.Sleep(50 * time.Millisecond)
-
- // Now execute send one more SYN. The stack should not respond as the backlog
- // is full at this point.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + lastPortOffset,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(context.TestInitialSequenceNumber),
- RcvWnd: 30000,
- })
- c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
-
- // Try to accept the connections in the backlog.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- for i := 0; i < listenBacklog; i++ {
- _, _, err = c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
- }
-
- // Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept(nil)
- if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- select {
- case <-ch:
- t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
- case <-time.After(1 * time.Second):
- }
- }
-
- // Now a new handshake must succeed.
- executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
-
- newEP, _, err := c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- newEP, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now verify that the TCP socket is usable and in a connected state.
- data := "Don't panic"
- var r strings.Reader
- r.Reset(data)
- newEP.Write(&r, tcpip.WriteOptions{})
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- if string(tcp.Payload()) != data {
- t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
- }
-}
-
-// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a
-// non unicast IPv4 address are not accepted.
-func TestListenNoAcceptNonUnicastV4(t *testing.T) {
- multicastAddr := tcpiptestutil.MustParse4("224.0.1.2")
- otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3")
- subnet := context.StackAddrWithPrefix.Subnet()
- subnetBroadcastAddr := subnet.Broadcast()
-
- tests := []struct {
- name string
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- }{
- {
- name: "SourceUnspecified",
- srcAddr: header.IPv4Any,
- dstAddr: context.StackAddr,
- },
- {
- name: "SourceBroadcast",
- srcAddr: header.IPv4Broadcast,
- dstAddr: context.StackAddr,
- },
- {
- name: "SourceOurMulticast",
- srcAddr: multicastAddr,
- dstAddr: context.StackAddr,
- },
- {
- name: "SourceOtherMulticast",
- srcAddr: otherMulticastAddr,
- dstAddr: context.StackAddr,
- },
- {
- name: "DestUnspecified",
- srcAddr: context.TestAddr,
- dstAddr: header.IPv4Any,
- },
- {
- name: "DestBroadcast",
- srcAddr: context.TestAddr,
- dstAddr: header.IPv4Broadcast,
- },
- {
- name: "DestOurMulticast",
- srcAddr: context.TestAddr,
- dstAddr: multicastAddr,
- },
- {
- name: "DestOtherMulticast",
- srcAddr: context.TestAddr,
- dstAddr: otherMulticastAddr,
- },
- {
- name: "SrcSubnetBroadcast",
- srcAddr: subnetBroadcastAddr,
- dstAddr: context.StackAddr,
- },
- {
- name: "DestSubnetBroadcast",
- srcAddr: context.TestAddr,
- dstAddr: subnetBroadcastAddr,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
- t.Fatalf("JoinGroup failed: %s", err)
- }
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacketWithAddrs(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- }, test.srcAddr, test.dstAddr)
- c.CheckNoPacket("Should not have received a response")
-
- // Handle normal packet.
- c.SendPacketWithAddrs(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- }, context.TestAddr, context.StackAddr)
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs)+1)))
- })
- }
-}
-
-// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
-// non unicast IPv6 address are not accepted.
-func TestListenNoAcceptNonUnicastV6(t *testing.T) {
- multicastAddr := tcpiptestutil.MustParse6("ff0e::101")
- otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102")
-
- tests := []struct {
- name string
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- }{
- {
- "SourceUnspecified",
- header.IPv6Any,
- context.StackV6Addr,
- },
- {
- "SourceAllNodes",
- header.IPv6AllNodesMulticastAddress,
- context.StackV6Addr,
- },
- {
- "SourceOurMulticast",
- multicastAddr,
- context.StackV6Addr,
- },
- {
- "SourceOtherMulticast",
- otherMulticastAddr,
- context.StackV6Addr,
- },
- {
- "DestUnspecified",
- context.TestV6Addr,
- header.IPv6Any,
- },
- {
- "DestAllNodes",
- context.TestV6Addr,
- header.IPv6AllNodesMulticastAddress,
- },
- {
- "DestOurMulticast",
- context.TestV6Addr,
- multicastAddr,
- },
- {
- "DestOtherMulticast",
- context.TestV6Addr,
- otherMulticastAddr,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil {
- t.Fatalf("JoinGroup failed: %s", err)
- }
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendV6PacketWithAddrs(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- }, test.srcAddr, test.dstAddr)
- c.CheckNoPacket("Should not have received a response")
-
- // Handle normal packet.
- c.SendV6PacketWithAddrs(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- }, context.TestV6Addr, context.StackV6Addr)
- checker.IPv6(t, c.GetV6Packet(),
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs)+1)))
- })
- }
-}
-
-func TestListenSynRcvdQueueFull(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send two SYN's the first one should get a SYN-ACK, the
- // second one should not get any response and is dropped as
- // the accept queue is full.
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- iss := seqnum.Value(tcp.SequenceNumber())
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs) + 1),
- }
- checker.IPv4(t, b, checker.TCP(tcpCheckers...))
-
- // Now complete the previous connection.
- // Send ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- // Verify if that is delivered to the accept queue.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
- <-ch
-
- // Now execute send one more SYN. The stack should not respond as the backlog
- // is full at this point.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + 1,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(889),
- RcvWnd: 30000,
- })
- c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
-
- // Try to accept the connections in the backlog.
- newEP, _, err := c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- newEP, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now verify that the TCP socket is usable and in a connected state.
- data := "Don't panic"
- var r strings.Reader
- r.Reset(data)
- newEP.Write(&r, tcpip.WriteOptions{})
- pkt := c.GetPacket()
- tcp = header.IPv4(pkt).Payload()
- if string(tcp.Payload()) != data {
- t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
- }
-}
-
-func TestListenBacklogFullSynCookieInUse(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Test for SynCookies usage after filling up the backlog.
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- executeHandshake(t, c, context.TestPort, false)
-
- // Wait for this to be delivered to the accept queue.
- time.Sleep(50 * time.Millisecond)
-
- // Send a SYN request.
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- // pick a different src port for new SYN.
- SrcPort: context.TestPort + 1,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
- // The Syn should be dropped as the endpoint's backlog is full.
- c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
-
- // Verify that there is only one acceptable connection at this point.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- _, _, err = c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept(nil)
- if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- select {
- case <-ch:
- t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
- case <-time.After(1 * time.Second):
- }
- }
-}
-
-func TestSYNRetransmit(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Start listening.
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send the same SYN packet multiple times. We should still get a valid SYN-ACK
- // reply.
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- for i := 0; i < 5; i++ {
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
- }
-
- // Receive the SYN-ACK reply.
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs) + 1),
- }
- checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...))
-}
-
-func TestSynRcvdBadSeqNumber(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Start listening.
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- iss := seqnum.Value(tcpHdr.SequenceNumber())
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs) + 1),
- }
- checker.IPv4(t, b, checker.TCP(tcpCheckers...))
-
- // Now send a packet with an out-of-window sequence number
- largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: largeSeqnum,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- // Should receive an ACK with the expected SEQ number
- b = c.GetPacket()
- tcpCheckers = []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPAckNum(uint32(irs) + 1),
- checker.TCPSeqNum(uint32(iss + 1)),
- }
- checker.IPv4(t, b, checker.TCP(tcpCheckers...))
-
- // Now that the socket replied appropriately with the ACK,
- // complete the connection to test that the large SEQ num
- // did not change the state from SYN-RCVD.
-
- // Get setup to be notified about connection establishment.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- // Send ACK to move to ESTABLISHED state.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- <-ch
- newEP, _, err := c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- // Now verify that the TCP socket is usable and in a connected state.
- data := "Don't panic"
- var r strings.Reader
- r.Reset(data)
- if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- pkt := c.GetPacket()
- tcpHdr = header.IPv4(pkt).Payload()
- if string(tcpHdr.Payload()) != data {
- t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
- }
-}
-
-func TestPassiveConnectionAttemptIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- c.EP = ep
- if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- stats := c.Stack().Stats()
- want := stats.TCP.PassiveConnectionOpenings.Value() + 1
-
- srcPort := uint16(context.TestPort)
- executeHandshake(t, c, srcPort+1, false)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- // Verify that there is only one acceptable connection at this point.
- _, _, err = c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want)
- }
-}
-
-func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- c.EP = ep
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- srcPort := uint16(context.TestPort)
- // Now attempt a handshakes it will fill up the accept backlog.
- executeHandshake(t, c, srcPort, false)
-
- // Give time for the final ACK to be processed as otherwise the next handshake could
- // get accepted before the previous one based on goroutine scheduling.
- time.Sleep(50 * time.Millisecond)
-
- want := stats.TCP.ListenOverflowSynDrop.Value() + 1
-
- // Now we will send one more SYN and this one should get dropped
- // Send a SYN request.
- c.SendPacket(nil, &context.Headers{
- SrcPort: srcPort + 2,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(context.TestInitialSequenceNumber),
- RcvWnd: 30000,
- })
-
- checkValid := func() []error {
- var errors []error
- if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
- errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want))
- }
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
- errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want))
- }
- return errors
- }
-
- start := time.Now()
- for time.Since(start) < time.Minute && len(checkValid()) > 0 {
- time.Sleep(50 * time.Millisecond)
- }
- for _, err := range checkValid() {
- t.Error(err)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- // Now check that there is one acceptable connections.
- _, _, err = c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- <-ch
- _, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
- }
-}
-
-func TestListenDropIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- c.Create(-1 /*epRcvBuf*/)
-
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if err := c.EP.Listen(1 /*backlog*/); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- initialDropped := stats.DroppedPackets.Value()
-
- // Send RST, FIN segments, that are expected to be dropped by the listener.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagRst,
- })
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagFin,
- })
-
- // To ensure that the RST, FIN sent earlier are indeed received and ignored
- // by the listener, send a SYN and wait for the SYN to be ACKd.
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- })
- checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs)+1),
- ))
-
- if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want {
- t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want)
- }
-}
-
-func TestEndpointBindListenAcceptState(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- ept := endpointTester{ep}
- ept.CheckReadError(t, &tcpip.ErrNotConnected{})
- if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
- t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- aep, _, err := ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- aep, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
- if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
- {
- err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" {
- t.Errorf("Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
- // Listening endpoint remains in listen state.
- if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- ep.Close()
- // Give worker goroutines time to receive the close notification.
- time.Sleep(1 * time.Second)
- if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
- // Accepted endpoint remains open when the listen endpoint is closed.
- if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
-}
-
-// This test verifies that the auto tuning does not grow the receive buffer if
-// the application is not reading the data actively.
-func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
- const mtu = 1500
- const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
-
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- stk := c.Stack()
- // Set lower limits for auto-tuning tests. This is required because the
- // test stops the worker which can cause packets to be dropped because
- // the segment queue holding unprocessed packets is limited to 500.
- const receiveBufferSize = 80 << 10 // 80KB.
- const maxReceiveBufferSize = receiveBufferSize * 10
- {
- opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- // Enable auto-tuning.
- {
- opt := tcpip.TCPModerateReceiveBufferOption(true)
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
- // Change the expected window scale to match the value needed for the
- // maximum buffer size defined above.
- c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
-
- rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
-
- // NOTE: The timestamp values in the sent packets are meaningless to the
- // peer so we just increment the timestamp value by 1 every batch as we
- // are not really using them for anything. Send a single byte to verify
- // the advertised window.
- tsVal := rawEP.TSVal + 1
-
- // Introduce a 25ms latency by delaying the first byte.
- latency := 25 * time.Millisecond
- time.Sleep(latency)
- // Send an initial payload with atleast segment overhead size. The receive
- // window would not grow for smaller segments.
- rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal)
-
- pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
- rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize()
-
- time.Sleep(25 * time.Millisecond)
-
- // Allocate a large enough payload for the test.
- payloadSize := receiveBufferSize * 2
- b := make([]byte, payloadSize)
-
- worker := (c.EP).(interface {
- StopWork()
- ResumeWork()
- })
- tsVal++
-
- // Stop the worker goroutine.
- worker.StopWork()
- start := 0
- end := payloadSize / 2
- packetsSent := 0
- for ; start < end; start += mss {
- packetEnd := start + mss
- if start+mss > end {
- packetEnd = end
- }
- rawEP.SendPacketWithTS(b[start:packetEnd], tsVal)
- packetsSent++
- }
-
- // Resume the worker so that it only sees the packets once all of them
- // are waiting to be read.
- worker.ResumeWork()
-
- // Since we sent almost the full receive buffer worth of data (some may have
- // been dropped due to segment overheads), we should get a zero window back.
- pkt = c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(pkt).Payload())
- gotRcvWnd := tcpHdr.WindowSize()
- wantAckNum := tcpHdr.AckNumber()
- if got, want := int(gotRcvWnd), 0; got != want {
- t.Fatalf("got rcvWnd: %d, want: %d", got, want)
- }
-
- time.Sleep(25 * time.Millisecond)
- // Verify that sending more data when receiveBuffer is exhausted.
- rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
-
- // Now read all the data from the endpoint and verify that advertised
- // window increases to the full available buffer size.
- for {
- _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- break
- }
- }
-
- // Verify that we receive a non-zero window update ACK. When running
- // under thread santizer this test can end up sending more than 1
- // ack, 1 for the non-zero window
- p := c.GetPacket()
- checker.IPv4(t, p, checker.TCP(
- checker.TCPAckNum(wantAckNum),
- func(t *testing.T, h header.Transport) {
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- // We use 10% here as the error margin upwards as the initial window we
- // got was afer 1 segment was already in the receive buffer queue.
- tolerance := 1.1
- if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) {
- t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance))
- }
- },
- ))
-}
-
-// This test verifies that the advertised window is auto-tuned up as the
-// application is reading the data that is being received.
-func TestReceiveBufferAutoTuning(t *testing.T) {
- const mtu = 1500
- const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
-
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- // Enable Auto-tuning.
- stk := c.Stack()
- // Disable out of window rate limiting for this test by setting it to 0 as we
- // use out of window ACKs to measure the advertised window.
- var tcpInvalidRateLimit stack.TCPInvalidRateLimitOption
- if err := stk.SetOption(tcpInvalidRateLimit); err != nil {
- t.Fatalf("e.stack.SetOption(%#v) = %s", tcpInvalidRateLimit, err)
- }
-
- const receiveBufferSize = 80 << 10 // 80KB.
- const maxReceiveBufferSize = receiveBufferSize * 10
- {
- opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- // Enable auto-tuning.
- {
- opt := tcpip.TCPModerateReceiveBufferOption(true)
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
- // Change the expected window scale to match the value needed for the
- // maximum buffer size used by stack.
- c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
-
- rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
- tsVal := rawEP.TSVal
- rawEP.NextSeqNum--
- rawEP.SendPacketWithTS(nil, tsVal)
- rawEP.NextSeqNum++
- pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
- curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
- scaleRcvWnd := func(rcvWnd int) uint16 {
- return uint16(rcvWnd >> c.WindowScale)
- }
- // Allocate a large array to send to the endpoint.
- b := make([]byte, receiveBufferSize*48)
-
- // In every iteration we will send double the number of bytes sent in
- // the previous iteration and read the same from the app. The received
- // window should grow by at least 2x of bytes read by the app in every
- // RTT.
- offset := 0
- payloadSize := receiveBufferSize / 8
- worker := (c.EP).(interface {
- StopWork()
- ResumeWork()
- })
- latency := 1 * time.Millisecond
- for i := 0; i < 5; i++ {
- tsVal++
-
- // Stop the worker goroutine.
- worker.StopWork()
- start := offset
- end := offset + payloadSize
- totalSent := 0
- packetsSent := 0
- for ; start < end; start += mss {
- rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
- totalSent += mss
- packetsSent++
- }
-
- // Resume it so that it only sees the packets once all of them
- // are waiting to be read.
- worker.ResumeWork()
-
- // Give 1ms for the worker to process the packets.
- time.Sleep(1 * time.Millisecond)
-
- lastACK := c.GetPacket()
- // Discard any intermediate ACKs and only check the last ACK we get in a
- // short time period of few ms.
- for {
- time.Sleep(1 * time.Millisecond)
- pkt := c.GetPacketNonBlocking()
- if pkt == nil {
- break
- }
- lastACK = pkt
- }
- if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want {
- t.Fatalf("advertised window got: %d, want <= %d", got, want)
- }
-
- // Now read all the data from the endpoint and invoke the
- // moderation API to allow for receive buffer auto-tuning
- // to happen before we measure the new window.
- totalCopied := 0
- for {
- res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- break
- }
- totalCopied += res.Count
- }
-
- // Invoke the moderation API. This is required for auto-tuning
- // to happen. This method is normally expected to be invoked
- // from a higher layer than tcpip.Endpoint. So we simulate
- // copying to userspace by invoking it explicitly here.
- c.EP.ModerateRecvBuf(totalCopied)
-
- // Now send a keep-alive packet to trigger an ACK so that we can
- // measure the new window.
- rawEP.NextSeqNum--
- rawEP.SendPacketWithTS(nil, tsVal)
- rawEP.NextSeqNum++
-
- if i == 0 {
- // In the first iteration the receiver based RTT is not
- // yet known as a result the moderation code should not
- // increase the advertised window.
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd))
- } else {
- // Read loop above could generate an ACK if the window had dropped to
- // zero and then read had opened it up.
- lastACK := c.GetPacket()
- // Discard any intermediate ACKs and only check the last ACK we get in a
- // short time period of few ms.
- for {
- time.Sleep(1 * time.Millisecond)
- pkt := c.GetPacketNonBlocking()
- if pkt == nil {
- break
- }
- lastACK = pkt
- }
- curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale
- // If thew new current window is close maxReceiveBufferSize then terminate
- // the loop. This can happen before all iterations are done due to timing
- // differences when running the test.
- if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 {
- break
- }
- // Increase the latency after first two iterations to
- // establish a low RTT value in the receiver since it
- // only tracks the lowest value. This ensures that when
- // ModerateRcvBuf is called the elapsed time is always >
- // rtt. Without this the test is flaky due to delays due
- // to scheduling/wakeup etc.
- latency += 50 * time.Millisecond
- }
- time.Sleep(latency)
- offset += payloadSize
- payloadSize *= 2
- }
- // Check that at the end of our iterations the receive window grew close to the maximum
- // permissible size of maxReceiveBufferSize/2
- if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want {
- t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want)
- }
-
-}
-
-func TestDelayEnabled(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- checkDelayOption(t, c, false, false) // Delay is disabled by default.
-
- for _, delayEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- opt := tcpip.TCPDelayEnabled(delayEnabled)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err)
- }
- checkDelayOption(t, c, opt, delayEnabled)
- })
- }
-}
-
-func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) {
- t.Helper()
-
- var gotDelayEnabled tcpip.TCPDelayEnabled
- if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
- t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err)
- }
- if gotDelayEnabled != wantDelayEnabled {
- t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
- }
-
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
- if err != nil {
- t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
- }
- gotDelayOption := ep.SocketOptions().GetDelayOption()
- if gotDelayOption != wantDelayOption {
- t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption)
- }
-}
-
-func TestTCPLingerTimeout(t *testing.T) {
- c := context.New(t, 1500 /* mtu */)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- testCases := []struct {
- name string
- tcpLingerTimeout time.Duration
- want time.Duration
- }{
- {"NegativeLingerTimeout", -123123, -1},
- // Zero is treated same as the stack's default TCP_LINGER2 timeout.
- {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout},
- {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
- // Values > stack's TCPLingerTimeout are capped to the stack's
- // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
- {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout},
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)
- if err := c.EP.SetSockOpt(&v); err != nil {
- t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err)
- }
-
- v = 0
- if err := c.EP.GetSockOpt(&v); err != nil {
- t.Fatalf("GetSockOpt(&%T) = %s", v, err)
- }
- if got, want := time.Duration(v), tc.want; got != want {
- t.Fatalf("got linger timeout = %s, want = %s", got, want)
- }
- })
- }
-}
-
-func TestTCPTimeWaitRSTIgnored(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Now send a RST and this should be ignored and not
- // generate an ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagRst,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- })
-
- c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second)
-
- // Out of order ACK should generate an immediate ACK in
- // TIME_WAIT.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 3,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-}
-
-func TestTCPTimeWaitOutOfOrder(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Out of order ACK should generate an immediate ACK in
- // TIME_WAIT.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 3,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-}
-
-func TestTCPTimeWaitNewSyn(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Send a SYN request w/ sequence number lower than
- // the highest sequence number sent. We just reuse
- // the same number.
- iss = seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
-
- // drain any older notifications from the notification channel before attempting
- // 2nd connection.
- select {
- case <-ch:
- default:
- }
-
- // Send a SYN request w/ sequence number higher than
- // the highest sequence number sent.
- iss = iss.Add(3)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b = c.GetPacket()
- tcpHdr = header.IPv4(b).Payload()
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders = &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-}
-
-func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
- // after 5 seconds in TIME_WAIT state.
- tcpTimeWaitTimeout := 5 * time.Second
- opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err)
- }
-
- want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- time.Sleep(2 * time.Second)
-
- // Now send a duplicate FIN. This should cause the TIME_WAIT to extend
- // by another 5 seconds and also send us a duplicate ACK as it should
- // indicate that the final ACK was potentially lost.
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Sleep for 4 seconds so at this point we are 1 second past the
- // original tcpLingerTimeout of 5 seconds.
- time.Sleep(4 * time.Second)
-
- // Send an ACK and it should not generate any packet as the socket
- // should still be in TIME_WAIT for another another 5 seconds due
- // to the duplicate FIN we sent earlier.
- *ackHeaders = *finHeaders
- ackHeaders.SeqNum = ackHeaders.SeqNum + 1
- ackHeaders.Flags = header.TCPFlagAck
- c.SendPacket(nil, ackHeaders)
-
- c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second)
- // Now sleep for another 2 seconds so that we are past the
- // extended TIME_WAIT of 7 seconds (2 + 5).
- time.Sleep(2 * time.Second)
-
- // Resend the same ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Receive the RST that should be generated as there is no valid
- // endpoint.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst)))
-
- if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want)
- }
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
-}
-
-func TestTCPCloseWithData(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
- // after 5 seconds in TIME_WAIT state.
- tcpTimeWaitTimeout := 5 * time.Second
- opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err)
- }
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- RcvWnd: 30000,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now trigger a passive close by sending a FIN.
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- RcvWnd: 30000,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Now write a few bytes and then close the endpoint.
- data := []byte{1, 2, 3}
-
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received.
- b = c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Errorf("got data = %x, want = %x", p, data)
- }
-
- c.EP.Close()
- // Check the FIN.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))),
- checker.TCPAckNum(uint32(iss+2)),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- // First send a partial ACK.
- ackHeaders = &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 2,
- AckNum: c.IRS + 1 + seqnum.Value(len(data)-1),
- RcvWnd: 30000,
- }
- c.SendPacket(nil, ackHeaders)
-
- // Now send a full ACK.
- ackHeaders = &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 2,
- AckNum: c.IRS + 1 + seqnum.Value(len(data)),
- RcvWnd: 30000,
- }
- c.SendPacket(nil, ackHeaders)
-
- // Now ACK the FIN.
- ackHeaders.AckNum++
- c.SendPacket(nil, ackHeaders)
-
- // Now send an ACK and we should get a RST back as the endpoint should
- // be in CLOSED state.
- ackHeaders = &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 2,
- AckNum: c.IRS + 1 + seqnum.Value(len(data)),
- RcvWnd: 30000,
- }
- c.SendPacket(nil, ackHeaders)
-
- // Check the RST.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst)))
-}
-
-func TestTCPUserTimeout(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- initRTO := 1 * time.Second
- minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
- }
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
- defer c.WQ.EventUnregister(&waitEntry)
-
- origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
-
- // Ensure that on the next retransmit timer fire, the user timeout has
- // expired.
- userTimeout := initRTO / 2
- v := tcpip.TCPUserTimeoutOption(userTimeout)
- if err := c.EP.SetSockOpt(&v); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err)
- }
-
- // Send some data and wait before ACKing it.
- view := make([]byte, 3)
- var r bytes.Reader
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- next := uint32(c.IRS) + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- // Wait for the retransmit timer to be fired and the user timeout to cause
- // close of the connection.
- select {
- case <-notifyCh:
- case <-time.After(2 * initRTO):
- t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout)
- }
-
- // No packet should be received as the connection should be silently
- // closed due to timeout.
- c.CheckNoPacket("unexpected packet received after userTimeout has expired")
-
- next += uint32(len(view))
-
- // The connection should be terminated after userTimeout has expired.
- // Send an ACK to trigger a RST from the stack as the endpoint should
- // be dead.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(0)),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrTimeout{})
-
- if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-}
-
-func TestKeepaliveWithUserTimeout(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
-
- const keepAliveIdle = 100 * time.Millisecond
- const keepAliveInterval = 3 * time.Second
- keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle)
- if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err)
- }
- keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval)
- if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err)
- }
- if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil {
- t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err)
- }
- c.EP.SocketOptions().SetKeepAlive(true)
-
- // Set userTimeout to be the duration to be 1 keepalive
- // probes. Which means that after the first probe is sent
- // the second one should cause the connection to be
- // closed due to userTimeout being hit.
- userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval)
- if err := c.EP.SetSockOpt(&userTimeout); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err)
- }
-
- // Check that the connection is still alive.
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Now receive 1 keepalives, but don't ACK it.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Sleep for a litte over the KeepAlive interval to make sure
- // the timer has time to fire after the last ACK and close the
- // close the socket.
- time.Sleep(keepAliveInterval + keepAliveInterval/2)
-
- // The connection should be closed with a timeout.
- // Send an ACK to trigger a RST from the stack as the endpoint should
- // be dead.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS + 1,
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(0)),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
-
- ept.CheckReadError(t, &tcpip.ErrTimeout{})
- if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-}
-
-func TestIncreaseWindowOnRead(t *testing.T) {
- // This test ensures that the endpoint sends an ack,
- // after read() when the window grows by more than 1 MSS.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- const rcvBuf = 65535 * 10
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf)
-
- // Write chunks of ~30000 bytes. It's important that two
- // payloads make it equal or longer than MSS.
- remain := rcvBuf * 2
- sent := 0
- data := make([]byte, defaultMTU/2)
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for remain > len(data) {
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(sent)),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- sent += len(data)
- remain -= len(data)
- pkt := c.GetPacket()
- checker.IPv4(t, pkt,
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- // Break once the window drops below defaultMTU/2
- if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 {
- break
- }
- }
-
- // We now have < 1 MSS in the buffer space. Read at least > 2 MSS
- // worth of data as receive buffer space
- w := tcpip.LimitedWriter{
- W: ioutil.Discard,
- // defaultMTU is a good enough estimate for the MSS used for this
- // connection.
- N: defaultMTU * 2,
- }
- for w.N != 0 {
- _, err := c.EP.Read(&w, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
- }
-
- // After reading > MSS worth of data, we surely crossed MSS. See the ack:
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPWindow(uint16(0xffff)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestIncreaseWindowOnBufferResize(t *testing.T) {
- // This test ensures that the endpoint sends an ack,
- // after available recv buffer grows to more than 1 MSS.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- const rcvBuf = 65535 * 10
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf)
-
- // Write chunks of ~30000 bytes. It's important that two
- // payloads make it equal or longer than MSS.
- remain := rcvBuf
- sent := 0
- data := make([]byte, defaultMTU/2)
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for remain > len(data) {
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(sent)),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- sent += len(data)
- remain -= len(data)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPWindowLessThanEq(0xffff),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- // Increasing the buffer from should generate an ACK,
- // since window grew from small value to larger equal MSS
- c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*4, true /* notify */)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPWindow(uint16(0xffff)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestTCPDeferAccept(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- const tcpDeferAccept = 1 * time.Second
- tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept)
- if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err)
- }
-
- irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
-
- _, _, err := c.EP.Accept(nil)
- if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" {
- t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d)
- }
-
- // Send data. This should result in an acceptable endpoint.
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(iss+1)),
- checker.TCPAckNum(uint32(irs+5))))
-
- // Give a bit of time for the socket to be delivered to the accept queue.
- time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
- }
-
- aep.Close()
- // Closing aep without reading the data should trigger a RST.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.TCPSeqNum(uint32(iss+1)),
- checker.TCPAckNum(uint32(irs+5))))
-}
-
-func TestTCPDeferAcceptTimeout(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- const tcpDeferAccept = 1 * time.Second
- tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept)
- if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err)
- }
-
- irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
-
- _, _, err := c.EP.Accept(nil)
- if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" {
- t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d)
- }
-
- // Sleep for a little of the tcpDeferAccept timeout.
- time.Sleep(tcpDeferAccept + 100*time.Millisecond)
-
- // On timeout expiry we should get a SYN-ACK retransmission.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs)+1)))
-
- // Send data. This should result in an acceptable endpoint.
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(iss+1)),
- checker.TCPAckNum(uint32(irs+5))))
-
- // Give sometime for the endpoint to be delivered to the accept queue.
- time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
- }
-
- aep.Close()
- // Closing aep without reading the data should trigger a RST.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.TCPSeqNum(uint32(iss+1)),
- checker.TCPAckNum(uint32(irs+5))))
-}
-
-func TestResetDuringClose(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRecvBuf */)
- // Send some data to make sure there is some unread
- // data to trigger a reset on c.Close.
- irs := c.IRS
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: irs.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(irs.Add(1))),
- checker.TCPAckNum(uint32(iss)+4)))
-
- // Close in a separate goroutine so that we can trigger
- // a race with the RST we send below. This should not
- // panic due to the route being released depeding on
- // whether Close() sends an active RST or the RST sent
- // below is processed by the worker first.
- var wg sync.WaitGroup
-
- wg.Add(1)
- go func() {
- defer wg.Done()
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- SeqNum: iss.Add(4),
- AckNum: c.IRS.Add(5),
- RcvWnd: 30000,
- Flags: header.TCPFlagRst,
- })
- }()
-
- wg.Add(1)
- go func() {
- defer wg.Done()
- c.EP.Close()
- }()
-
- wg.Wait()
-}
-
-func TestStackTimeWaitReuse(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- s := c.Stack()
- var twReuse tcpip.TCPTimeWaitReuseOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err)
- }
- if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want {
- t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
- }
-}
-
-func TestSetStackTimeWaitReuse(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- s := c.Stack()
- testCases := []struct {
- v int
- err tcpip.Error
- }{
- {int(tcpip.TCPTimeWaitReuseDisabled), nil},
- {int(tcpip.TCPTimeWaitReuseGlobal), nil},
- {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil},
- {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}},
- {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}},
- }
-
- for _, tc := range testCases {
- opt := tcpip.TCPTimeWaitReuseOption(tc.v)
- err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)
- if got, want := err, tc.err; got != want {
- t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err)
- }
- if tc.err != nil {
- continue
- }
-
- var twReuse tcpip.TCPTimeWaitReuseOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err)
- }
-
- if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want {
- t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
- }
- }
-}
-
-func TestHandshakeRTT(t *testing.T) {
- type testCase struct {
- connect bool
- tsEnabled bool
- useCookie bool
- retrans bool
- delay time.Duration
- wantRTT time.Duration
- }
- var testCases []testCase
- for _, connect := range []bool{false, true} {
- for _, tsEnabled := range []bool{false, true} {
- for _, useCookie := range []bool{false, true} {
- for _, retrans := range []bool{false, true} {
- if connect && useCookie {
- continue
- }
- delay := 800 * time.Millisecond
- if retrans {
- delay = 1200 * time.Millisecond
- }
- wantRTT := delay
- // If syncookie is enabled, sample RTT only when TS option is enabled.
- if !retrans && useCookie && !tsEnabled {
- wantRTT = 0
- }
- // If retransmitted, sample RTT only when TS option is enabled.
- if retrans && !tsEnabled {
- wantRTT = 0
- }
- testCases = append(testCases, testCase{connect, tsEnabled, useCookie, retrans, delay, wantRTT})
- }
- }
- }
- }
- for _, tt := range testCases {
- tt := tt
- t.Run(fmt.Sprintf("connect=%t,TS=%t,cookie=%t,retrans=%t)", tt.connect, tt.tsEnabled, tt.useCookie, tt.retrans), func(t *testing.T) {
- t.Parallel()
- c := context.New(t, defaultMTU)
- if tt.useCookie {
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
- synOpts := header.TCPSynOptions{}
- if tt.tsEnabled {
- synOpts.TS = true
- synOpts.TSVal = 42
- }
- if tt.connect {
- c.CreateConnectedWithOptions(synOpts, tt.delay)
- } else {
- synOpts.MSS = defaultIPv4MSS
- synOpts.WS = -1
- c.AcceptWithOptions(-1, synOpts, tt.delay)
- }
- var info tcpip.TCPInfoOption
- if err := c.EP.GetSockOpt(&info); err != nil {
- t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
- }
- if got := info.RTT.Round(tt.wantRTT); got != tt.wantRTT {
- t.Fatalf("got info.RTT=%s, expect %s", got, tt.wantRTT)
- }
- if info.RTTVar != 0 && tt.wantRTT == 0 {
- t.Fatalf("got info.RTTVar=%s, expect 0", info.RTTVar)
- }
- if info.RTTVar == 0 && tt.wantRTT != 0 {
- t.Fatalf("got info.RTTVar=0, expect non zero")
- }
- })
- }
-}
-
-func TestSetRTO(t *testing.T) {
- c := context.New(t, defaultMTU)
- minRTO, maxRTO := tcpRTOMinMax(t, c)
- for _, tt := range []struct {
- name string
- RTO time.Duration
- minRTO time.Duration
- maxRTO time.Duration
- err tcpip.Error
- }{
- {
- name: "invalid minRTO",
- minRTO: maxRTO + time.Second,
- err: &tcpip.ErrInvalidOptionValue{},
- },
- {
- name: "invalid maxRTO",
- maxRTO: minRTO - time.Millisecond,
- err: &tcpip.ErrInvalidOptionValue{},
- },
- {
- name: "valid minRTO",
- minRTO: maxRTO - time.Second,
- },
- {
- name: "valid maxRTO",
- maxRTO: minRTO + time.Millisecond,
- },
- } {
- t.Run(tt.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- var opt tcpip.SettableTransportProtocolOption
- if tt.minRTO > 0 {
- min := tcpip.TCPMinRTOOption(tt.minRTO)
- opt = &min
- }
- if tt.maxRTO > 0 {
- max := tcpip.TCPMaxRTOOption(tt.maxRTO)
- opt = &max
- }
- err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt)
- if got, want := err, tt.err; got != want {
- t.Fatalf("c.Stack().SetTransportProtocolOption(TCP, &%T(%v)) = %v, want = %v", opt, opt, got, want)
- }
- if tt.err == nil {
- minRTO, maxRTO := tcpRTOMinMax(t, c)
- if tt.minRTO > 0 && tt.minRTO != minRTO {
- t.Fatalf("got minRTO = %s, want %s", minRTO, tt.minRTO)
- }
- if tt.maxRTO > 0 && tt.maxRTO != maxRTO {
- t.Fatalf("got maxRTO = %s, want %s", maxRTO, tt.maxRTO)
- }
- }
- })
- }
-}
-
-func tcpRTOMinMax(t *testing.T, c *context.Context) (time.Duration, time.Duration) {
- t.Helper()
- var minOpt tcpip.TCPMinRTOOption
- var maxOpt tcpip.TCPMaxRTOOption
- if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &minOpt); err != nil {
- t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", minOpt, err)
- }
- if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &maxOpt); err != nil {
- t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", maxOpt, err)
- }
- return time.Duration(minOpt), time.Duration(maxOpt)
-}
-
-// generateRandomPayload generates a random byte slice of the specified length
-// causing a fatal test failure if it is unable to do so.
-func generateRandomPayload(t *testing.T, n int) []byte {
- t.Helper()
- buf := make([]byte, n)
- if _, err := rand.Read(buf); err != nil {
- t.Fatalf("rand.Read(buf) failed: %s", err)
- }
- return buf
-}
-
-func TestSendBufferTuning(t *testing.T) {
- const maxPayload = 536
- const mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
- const packetOverheadFactor = 2
-
- testCases := []struct {
- name string
- autoTuningDisabled bool
- }{
- {"autoTuningDisabled", true},
- {"autoTuningEnabled", false},
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- // Set the stack option for send buffer size.
- const defaultSndBufSz = maxPayload * tcp.InitialCwnd
- const maxSndBufSz = defaultSndBufSz * 10
- {
- opt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: defaultSndBufSz, Max: maxSndBufSz}
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- oldSz := c.EP.SocketOptions().GetSendBufferSize()
- if oldSz != defaultSndBufSz {
- t.Fatalf("Wrong send buffer size got %d want %d", oldSz, defaultSndBufSz)
- }
-
- if tc.autoTuningDisabled {
- c.EP.SocketOptions().SetSendBufferSize(defaultSndBufSz, true /* notify */)
- }
-
- data := make([]byte, maxPayload)
- for i := range data {
- data[i] = byte(i)
- }
-
- w, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&w, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&w)
-
- bytesRead := 0
- for {
- // Packets will be sent till the send buffer
- // size is reached.
- var r bytes.Reader
- r.Reset(data[bytesRead : bytesRead+maxPayload])
- _, err := c.EP.Write(&r, tcpip.WriteOptions{})
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- break
- }
-
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0)
- bytesRead += maxPayload
- data = append(data, data...)
- }
-
- // Send an ACK and wait for connection to become writable again.
- c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
- select {
- case <-ch:
- if err := c.EP.LastError(); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-
- outSz := int64(defaultSndBufSz)
- if !tc.autoTuningDisabled {
- // Calculate the new auto tuned send buffer.
- var info tcpip.TCPInfoOption
- if err := c.EP.GetSockOpt(&info); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
- outSz = int64(info.SndCwnd) * packetOverheadFactor * maxPayload
- }
-
- if newSz := c.EP.SocketOptions().GetSendBufferSize(); newSz != outSz {
- t.Fatalf("Wrong send buffer size, got %d want %d", newSz, outSz)
- }
- })
- }
-}
-
-func TestTimestampSynCookies(t *testing.T) {
- clock := faketime.NewManualClock()
- tsNow := func() uint32 {
- return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Milliseconds())
- }
- // Advance the clock so that NowMonotonic is non-zero.
- clock.Advance(time.Second)
- c := context.NewWithOpts(t, context.Options{
- EnableV4: true,
- EnableV6: true,
- MTU: defaultMTU,
- Clock: clock,
- })
- defer c.Cleanup()
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
-
- tcpOpts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
- header.EncodeTSOption(42, 0, tcpOpts[2:])
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- RcvWnd: seqnum.Size(512),
- SeqNum: iss,
- TCPOpts: tcpOpts[:],
- })
- // Get the TSVal of SYN-ACK.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
- initialTSVal := tcpHdr.ParsedOptions().TSVal
- // derive the tsOffset.
- tsOffset := initialTSVal - tsNow()
-
- header.EncodeTSOption(420, initialTSVal, tcpOpts[2:])
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- RcvWnd: seqnum.Size(512),
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- TCPOpts: tcpOpts[:],
- })
- c.EP, _, err = ep.Accept(nil)
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- } else if err != nil {
- t.Fatalf("failed to accept: %s", err)
- }
-
- // Advance the clock again so that we expect the next TSVal to change.
- clock.Advance(time.Second)
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // The endpoint should have a correct TSOffset so that the received TSVal
- // should match our expectation.
- if got, want := header.TCP(header.IPv4(c.GetPacket()).Payload()).ParsedOptions().TSVal, tsNow()+tsOffset; got != want {
- t.Fatalf("got TSVal = %d, want %d", got, want)
- }
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
deleted file mode 100644
index 65925daa5..000000000
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ /dev/null
@@ -1,311 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "bytes"
- "math/rand"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// createConnectedWithTimestampOption creates and connects c.ep with the
-// timestamp option enabled.
-func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, TSVal: 1})
-}
-
-// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on
-// an active connect and sets the TS Echo Reply fields correctly when the
-// SYN-ACK also indicates support for the TS option and provides a TSVal.
-func TestTimeStampEnabledConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- rep := createConnectedWithTimestampOption(c)
-
- // Register for read and validate that we have data to read.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- // The following tests ensure that TS option once enabled behaves
- // correctly as described in
- // https://tools.ietf.org/html/rfc7323#section-4.3.
- //
- // We are not testing delayed ACKs here, but we do test out of order
- // packet delivery and filling the sequence number hole created due to
- // the out of order packet.
- //
- // The test also verifies that the sequence numbers and timestamps are
- // as expected.
- data := []byte{1, 2, 3}
-
- // First we increment tsVal by a small amount.
- tsVal := rep.TSVal + 100
- rep.SendPacketWithTS(data, tsVal)
- rep.VerifyACKWithTS(tsVal)
-
- // Next we send an out of order packet.
- rep.NextSeqNum += 3
- tsVal += 200
- rep.SendPacketWithTS(data, tsVal)
-
- // The ACK should contain the original sequenceNumber and an older TS.
- rep.NextSeqNum -= 6
- rep.VerifyACKWithTS(tsVal - 200)
-
- // Next we fill the hole and the returned ACK should contain the
- // cumulative sequence number acking all data sent till now and have the
- // latest timestamp sent below in its TSEcr field.
- tsVal -= 100
- rep.SendPacketWithTS(data, tsVal)
- rep.NextSeqNum += 3
- rep.VerifyACKWithTS(tsVal)
-
- // Increment tsVal by a large value that doesn't result in a wrap around.
- tsVal += 0x7fffffff
- rep.SendPacketWithTS(data, tsVal)
- rep.VerifyACKWithTS(tsVal)
-
- // Increment tsVal again by a large value which should cause the
- // timestamp value to wrap around. The returned ACK should contain the
- // wrapped around timestamp in its tsEcr field and not the tsVal from
- // the previous packet sent above.
- tsVal += 0x7fffffff
- rep.SendPacketWithTS(data, tsVal)
- rep.VerifyACKWithTS(tsVal)
-
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // There should be 5 views to read and each of them should
- // contain the same data.
- for i := 0; i < 5; i++ {
- buf := make([]byte, len(data))
- w := tcpip.SliceWriter(buf)
- result, err := c.EP.Read(&w, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("Unexpected error from Read: %v", err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: len(buf),
- Total: len(buf),
- }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
- }
- if got, want := buf, data; bytes.Compare(got, want) != 0 {
- t.Fatalf("Data is different: got: %v, want: %v", got, want)
- }
- }
-}
-
-// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an
-// active connect but if the SYN-ACK doesn't specify the TS option then
-// timestamp option is not enabled and future packets do not contain a
-// timestamp.
-func TestTimeStampDisabledConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
-}
-
-func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- if cookieEnabled {
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
-
- t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
- tsVal := rand.Uint32()
- c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
-
- // Now send some data and validate that timestamp is echoed correctly in the ACK.
- data := []byte{1, 2, 3}
-
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Unexpected error from Write: %s", err)
- }
-
- // Check that data is received and that the timestamp option TSEcr field
- // matches the expected value.
- b := c.GetPacket()
- checker.IPv4(t, b,
- // Add 12 bytes for the timestamp option + 2 NOPs to align at 4
- // byte boundary.
- checker.PayloadLen(len(data)+header.TCPMinimumSize+12),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
- checker.TCPWindow(wndSize),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- checker.TCPTimestampChecker(true, 0, tsVal+1),
- ),
- )
-}
-
-// TestTimeStampEnabledAccept tests that if the SYN on a passive connect
-// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK
-// and echoes the tsVal field of the original SYN in the tcEcr field of the
-// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify
-// that Timestamp option is enabled in both cases if requested in the original
-// SYN.
-func TestTimeStampEnabledAccept(t *testing.T) {
- testCases := []struct {
- cookieEnabled bool
- wndScale int
- wndSize uint16
- }{
- {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that.
- {false, 5, 0x4000},
- }
- for _, tc := range testCases {
- timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
- }
-}
-
-func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- if cookieEnabled {
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
-
- t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
- c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Now send some data with the accepted connection endpoint and validate
- // that no timestamp option is sent in the TCP segment.
- data := []byte{1, 2, 3}
-
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Unexpected error from Write: %s", err)
- }
-
- // Check that data is received and that the timestamp option is disabled
- // when SYN cookies are enabled/disabled.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
- checker.TCPWindow(wndSize),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- checker.TCPTimestampChecker(false, 0, 0),
- ),
- )
-}
-
-// TestTimeStampDisabledAccept tests that Timestamp option is not used when the
-// peer doesn't advertise it and connection is established with Accept().
-func TestTimeStampDisabledAccept(t *testing.T) {
- testCases := []struct {
- cookieEnabled bool
- wndScale int
- wndSize uint16
- }{
- {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of
- // that.
- {false, 5, 0x4000},
- }
- for _, tc := range testCases {
- timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
- }
-}
-
-func TestSendGreaterThanMTUWithOptions(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- createConnectedWithTimestampOption(c)
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- rep := createConnectedWithTimestampOption(c)
-
- // Register for read.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- droppedPacketsStat := c.Stack().Stats().DroppedPackets
- droppedPackets := droppedPacketsStat.Value()
- data := []byte{1, 2, 3}
- // Send a packet with no TCP options/timestamp.
- rep.SendPacket(data, nil)
-
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Assert that DroppedPackets was not incremented.
- if got, want := droppedPacketsStat.Value(), droppedPackets; got != want {
- t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want)
- }
-
- // Issue a read and we should data.
- var buf bytes.Buffer
- result, err := c.EP.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("Unexpected error from Read: %v", err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
- }
- if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 {
- t.Fatalf("Data is different: got: %v, want: %v", got, want)
- }
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go
new file mode 100644
index 000000000..4cb82fcc9
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package tcp
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
deleted file mode 100644
index ce6a2c31d..000000000
--- a/pkg/tcpip/transport/tcp/testing/context/BUILD
+++ /dev/null
@@ -1,26 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "context",
- testonly = 1,
- srcs = ["context.go"],
- visibility = [
- "//visibility:public",
- ],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/seqnum",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
deleted file mode 100644
index 88bb99354..000000000
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ /dev/null
@@ -1,1268 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package context provides a test context for use in tcp tests. It also
-// provides helper methods to assert/check certain behaviours.
-package context
-
-import (
- "bytes"
- "context"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- // StackAddr is the IPv4 address assigned to the stack.
- StackAddr = "\x0a\x00\x00\x01"
-
- // StackPort is used as the listening port in tests for passive
- // connects.
- StackPort = 1234
-
- // TestAddr is the source address for packets sent to the stack via the
- // link layer endpoint.
- TestAddr = "\x0a\x00\x00\x02"
-
- // TestPort is the TCP port used for packets sent to the stack
- // via the link layer endpoint.
- TestPort = 4096
-
- // StackV6Addr is the IPv6 address assigned to the stack.
- StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
-
- // TestV6Addr is the source address for packets sent to the stack via
- // the link layer endpoint.
- TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
-
- // StackV4MappedAddr is StackAddr as a mapped v6 address.
- StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
-
- // TestV4MappedAddr is TestAddr as a mapped v6 address.
- TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr
-
- // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0.
- V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
-
- // TestInitialSequenceNumber is the initial sequence number sent in packets that
- // are sent in response to a SYN or in the initial SYN sent to the stack.
- TestInitialSequenceNumber = 789
-)
-
-// StackAddrWithPrefix is StackAddr with its associated prefix length.
-var StackAddrWithPrefix = tcpip.AddressWithPrefix{
- Address: StackAddr,
- PrefixLen: 24,
-}
-
-// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length.
-var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{
- Address: StackV6Addr,
- PrefixLen: header.IIDOffsetInIPv6Address * 8,
-}
-
-// Headers is used to represent the TCP header fields when building a
-// new packet.
-type Headers struct {
- // SrcPort holds the src port value to be used in the packet.
- SrcPort uint16
-
- // DstPort holds the destination port value to be used in the packet.
- DstPort uint16
-
- // SeqNum is the value of the sequence number field in the TCP header.
- SeqNum seqnum.Value
-
- // AckNum represents the acknowledgement number field in the TCP header.
- AckNum seqnum.Value
-
- // Flags are the TCP flags in the TCP header.
- Flags header.TCPFlags
-
- // RcvWnd is the window to be advertised in the ReceiveWindow field of
- // the TCP header.
- RcvWnd seqnum.Size
-
- // TCPOpts holds the options to be sent in the option field of the TCP
- // header.
- TCPOpts []byte
-}
-
-// Options contains options for creating a new test context.
-type Options struct {
- // EnableV4 indicates whether IPv4 should be enabled.
- EnableV4 bool
-
- // EnableV6 indicates whether IPv4 should be enabled.
- EnableV6 bool
-
- // MTU indicates the maximum transmission unit on the link layer.
- MTU uint32
-
- // Clock that is used by Stack.
- Clock tcpip.Clock
-}
-
-// Context provides an initialized Network stack and a link layer endpoint
-// for use in TCP tests.
-type Context struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
-
- // IRS holds the initial sequence number in the SYN sent by endpoint in
- // case of an active connect or the sequence number sent by the endpoint
- // in the SYN-ACK sent in response to a SYN when listening in passive
- // mode.
- IRS seqnum.Value
-
- // Port holds the port bound by EP below in case of an active connect or
- // the listening port number in case of a passive connect.
- Port uint16
-
- // EP is the test endpoint in the stack owned by this context. This endpoint
- // is used in various tests to either initiate an active connect or is used
- // as a passive listening endpoint to accept inbound connections.
- EP tcpip.Endpoint
-
- // Wq is the wait queue associated with EP and is used to block for events
- // on EP.
- WQ waiter.Queue
-
- // TimeStampEnabled is true if ep is connected with the timestamp option
- // enabled.
- TimeStampEnabled bool
-
- // WindowScale is the expected window scale in SYN packets sent by
- // the stack.
- WindowScale uint8
-
- // RcvdWindowScale is the actual window scale sent by the stack in
- // SYN/SYN-ACK.
- RcvdWindowScale uint8
-}
-
-// New allocates and initializes a test context containing a new
-// stack and a link-layer endpoint.
-func New(t *testing.T, mtu uint32) *Context {
- return NewWithOpts(t, Options{
- EnableV4: true,
- EnableV6: true,
- MTU: mtu,
- })
-}
-
-// NewWithOpts allocates and initializes a test context containing a new
-// stack and a link-layer endpoint with specific options.
-func NewWithOpts(t *testing.T, opts Options) *Context {
- if opts.MTU == 0 {
- panic("MTU must be greater than 0")
- }
-
- stackOpts := stack.Options{
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- Clock: opts.Clock,
- }
- if opts.EnableV4 {
- stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol)
- }
- if opts.EnableV6 {
- stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol)
- }
- s := stack.New(stackOpts)
-
- const sendBufferSize = 1 << 20 // 1 MiB
- const recvBufferSize = 1 << 20 // 1 MiB
- // Allow minimum send/receive buffer sizes to be 1 during tests.
- sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err)
- }
-
- rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err)
- }
-
- // Increase minimum RTO in tests to avoid test flakes due to early
- // retransmit in case the test executors are overloaded and cause timers
- // to fire earlier than expected.
- minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second)
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
- t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
- }
-
- // Some of the congestion control tests send up to 640 packets, we so
- // set the channel size to 1000.
- ep := channel.New(1000, opts.MTU, "")
- wep := stack.LinkEndpoint(ep)
- if testing.Verbose() {
- wep = sniffer.New(ep)
- }
- nicOpts := stack.NICOptions{Name: "nic1"}
- if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
- }
- wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, ""))
- if testing.Verbose() {
- wep2 = sniffer.New(channel.New(1000, opts.MTU, ""))
- }
- opts2 := stack.NICOptions{Name: "nic2"}
- if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
- t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
- }
-
- var routeTable []tcpip.Route
-
- if opts.EnableV4 {
- v4ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: StackAddrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, v4ProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v4ProtocolAddr, err)
- }
- routeTable = append(routeTable, tcpip.Route{
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- })
- }
-
- if opts.EnableV6 {
- v6ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: StackV6AddrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, v6ProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v6ProtocolAddr, err)
- }
- routeTable = append(routeTable, tcpip.Route{
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- })
- }
-
- s.SetRouteTable(routeTable)
-
- return &Context{
- t: t,
- s: s,
- linkEP: ep,
- WindowScale: uint8(tcp.FindWndScale(recvBufferSize)),
- }
-}
-
-// Cleanup closes the context endpoint if required.
-func (c *Context) Cleanup() {
- if c.EP != nil {
- c.EP.Close()
- }
- c.Stack().Close()
-}
-
-// Stack returns a reference to the stack in the Context.
-func (c *Context) Stack() *stack.Stack {
- return c.s
-}
-
-// CheckNoPacketTimeout verifies that no packet is received during the time
-// specified by wait.
-func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
- c.t.Helper()
-
- ctx, cancel := context.WithTimeout(context.Background(), wait)
- defer cancel()
- if _, ok := c.linkEP.ReadContext(ctx); ok {
- c.t.Fatal(errMsg)
- }
-}
-
-// CheckNoPacket verifies that no packet is received for 1 second.
-func (c *Context) CheckNoPacket(errMsg string) {
- c.CheckNoPacketTimeout(errMsg, 1*time.Second)
-}
-
-// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies
-// that it is an IPv4 packet with the expected source and destination
-// addresses. If no packet is received in the specified timeout it will return
-// nil.
-func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte {
- c.t.Helper()
-
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
- if !ok {
- return nil
- }
-
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
-
- // Just check that the stack set the transport protocol number for outbound
- // TCP messages.
- // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part
- // of the headerinfo.
- if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
- c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- if p.Pkt.GSOOptions.Type != stack.GSONone && p.Pkt.GSOOptions.L3HdrLen != header.IPv4MinimumSize {
- c.t.Errorf("got L3HdrLen = %d, want = %d", p.Pkt.GSOOptions.L3HdrLen, header.IPv4MinimumSize)
- }
-
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
-}
-
-// GetPacket reads a packet from the link layer endpoint and verifies
-// that it is an IPv4 packet with the expected source and destination
-// addresses.
-func (c *Context) GetPacket() []byte {
- c.t.Helper()
-
- p := c.GetPacketWithTimeout(5 * time.Second)
- if p == nil {
- c.t.Fatalf("Packet wasn't written out")
- return nil
- }
-
- return p
-}
-
-// GetPacketNonBlocking reads a packet from the link layer endpoint
-// and verifies that it is an IPv4 packet with the expected source
-// and destination address. If no packet is available it will return
-// nil immediately.
-func (c *Context) GetPacketNonBlocking() []byte {
- c.t.Helper()
-
- p, ok := c.linkEP.Read()
- if !ok {
- return nil
- }
-
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
-
- // Just check that the stack set the transport protocol number for outbound
- // TCP messages.
- // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part
- // of the headerinfo.
- if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
- c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
-}
-
-// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
-func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, p1, p2 []byte, maxTotalSize int) {
- // Allocate a buffer data and headers.
- buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2))
- if len(buf) > maxTotalSize {
- buf = buf[:maxTotalSize]
- }
-
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(len(buf)),
- TTL: 65,
- Protocol: uint8(header.ICMPv4ProtocolNumber),
- SrcAddr: TestAddr,
- DstAddr: StackAddr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- icmp := header.ICMPv4(buf[header.IPv4MinimumSize:])
- icmp.SetType(typ)
- icmp.SetCode(code)
- const icmpv4VariableHeaderOffset = 4
- copy(icmp[icmpv4VariableHeaderOffset:], p1)
- copy(icmp[header.ICMPv4PayloadOffset:], p2)
- icmp.SetChecksum(0)
- checksum := ^header.Checksum(icmp, 0 /* initial */)
- icmp.SetChecksum(checksum)
-
- // Inject packet.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- })
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
-}
-
-// BuildSegment builds a TCP segment based on the given Headers and payload.
-func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
- return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr)
-}
-
-// BuildSegmentWithAddrs builds a TCP segment based on the given Headers,
-// payload and source and destination IPv4 addresses.
-func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
- copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts)
-
- // Initialize the IP header.
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(len(buf)),
- TTL: 65,
- Protocol: uint8(tcp.ProtocolNumber),
- SrcAddr: src,
- DstAddr: dst,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Initialize the TCP header.
- t := header.TCP(buf[header.IPv4MinimumSize:])
- t.Encode(&header.TCPFields{
- SrcPort: h.SrcPort,
- DstPort: h.DstPort,
- SeqNum: uint32(h.SeqNum),
- AckNum: uint32(h.AckNum),
- DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)),
- Flags: h.Flags,
- WindowSize: uint16(h.RcvWnd),
- })
-
- // Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
-
- // Calculate the TCP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- t.SetChecksum(^t.CalculateChecksum(xsum))
-
- // Inject packet.
- return buf.ToVectorisedView()
-}
-
-// SendSegment sends a TCP segment that has already been built and written to a
-// buffer.VectorisedView.
-func (c *Context) SendSegment(s buffer.VectorisedView) {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: s,
- })
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
-}
-
-// SendPacket builds and sends a TCP segment(with the provided payload & TCP
-// headers) in an IPv4 packet via the link layer endpoint.
-func (c *Context) SendPacket(payload []byte, h *Headers) {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: c.BuildSegment(payload, h),
- })
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
-}
-
-// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
-// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
-// provided source and destination IPv4 addresses.
-func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
- })
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
-}
-
-// SendAck sends an ACK packet.
-func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
- c.SendAckWithSACK(seq, bytesReceived, nil)
-}
-
-// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified.
-func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) {
- options := make([]byte, 40)
- offset := 0
- if len(sackBlocks) > 0 {
- offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
- }
-
- c.SendPacket(nil, &Headers{
- SrcPort: TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seq,
- AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
- RcvWnd: 30000,
- TCPOpts: options[:offset],
- })
-}
-
-// ReceiveAndCheckPacket reads a packet from the link layer endpoint and
-// verifies that the packet packet payload of packet matches the slice
-// of data indicated by offset & size.
-func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
- c.t.Helper()
-
- c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0)
-}
-
-// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint
-// and verifies that the packet packet payload of packet matches the slice of
-// data indicated by offset & size and skips optlen bytes in addition to the IP
-// TCP headers when comparing the data.
-func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) {
- c.t.Helper()
-
- b := c.GetPacket()
- checker.IPv4(c.t, b,
- checker.PayloadLen(size+header.TCPMinimumSize+optlen),
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- pdata := data[offset:][:size]
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 {
- c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
- }
-}
-
-// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint
-// and verifies that the packet packet payload of packet matches the slice of
-// data indicated by offset & size. It returns true if a packet was received and
-// processed.
-func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
- c.t.Helper()
-
- b := c.GetPacketNonBlocking()
- if b == nil {
- return false
- }
- checker.IPv4(c.t, b,
- checker.PayloadLen(size+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- pdata := data[offset:][:size]
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 {
- c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
- }
- return true
-}
-
-// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only
-// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6
-// only endpoint instead of a default dual stack socket.
-func (c *Context) CreateV6Endpoint(v6only bool) {
- var err tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- c.EP.SocketOptions().SetV6Only(v6only)
-}
-
-// GetV6Packet reads a single packet from the link layer endpoint of the context
-// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
-func (c *Context) GetV6Packet() []byte {
- c.t.Helper()
-
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
- if !ok {
- c.t.Fatalf("Packet wasn't written out")
- return nil
- }
-
- if p.Proto != ipv6.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
- }
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
- return b
-}
-
-// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
-// the context.
-func (c *Context) SendV6Packet(payload []byte, h *Headers) {
- c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
-}
-
-// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
-// endpoint of the context using the provided source and destination IPv6
-// addresses.
-func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
-
- // Initialize the IP header.
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
- TransportProtocol: tcp.ProtocolNumber,
- HopLimit: 65,
- SrcAddr: src,
- DstAddr: dst,
- })
-
- // Initialize the TCP header.
- t := header.TCP(buf[header.IPv6MinimumSize:])
- t.Encode(&header.TCPFields{
- SrcPort: h.SrcPort,
- DstPort: h.DstPort,
- SeqNum: uint32(h.SeqNum),
- AckNum: uint32(h.AckNum),
- DataOffset: header.TCPMinimumSize,
- Flags: h.Flags,
- WindowSize: uint16(h.RcvWnd),
- })
-
- // Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
-
- // Calculate the TCP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- t.SetChecksum(^t.CalculateChecksum(xsum))
-
- // Inject packet.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- })
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt)
-}
-
-// CreateConnected creates a connected TCP endpoint.
-func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
- c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
-}
-
-// Connect performs the 3-way handshake for c.EP with the provided Initial
-// Sequence Number (iss) and receive window(rcvWnd) and any options if
-// specified.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-//
-// PreCondition: c.EP must already be created.
-func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
- c.t.Helper()
-
- // Start connection attempt.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&waitEntry)
-
- err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- c.t.Fatalf("Unexpected return value from Connect: %v", err)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(c.t, b,
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */)
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- c.SendPacket(nil, &Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: options,
- })
-
- // Receive ACK packet.
- checker.IPv4(c.t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+1),
- ),
- )
-
- // Wait for connection to be established.
- select {
- case <-notifyCh:
- if err := c.EP.LastError(); err != nil {
- c.t.Fatalf("Unexpected error when connecting: %v", err)
- }
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for connection")
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- c.RcvdWindowScale = uint8(synOpts.WS)
- c.Port = tcpHdr.SourcePort()
-}
-
-// Create creates a TCP endpoint.
-func (c *Context) Create(epRcvBuf int) {
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if epRcvBuf != -1 {
- c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf)*2, true /* notify */)
- }
-}
-
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
- c.Create(epRcvBuf)
- c.Connect(iss, rcvWnd, options)
-}
-
-// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
-// sending data and ACK packets easy while being able to manipulate the sequence
-// numbers and timestamp values as needed.
-type RawEndpoint struct {
- C *Context
- SrcPort uint16
- DstPort uint16
- Flags header.TCPFlags
- NextSeqNum seqnum.Value
- AckNum seqnum.Value
- WndSize seqnum.Size
- RecentTS uint32 // Stores the latest timestamp to echo back.
- TSVal uint32 // TSVal stores the last timestamp sent by this endpoint.
-
- // SackPermitted is true if SACKPermitted option was negotiated for this endpoint.
- SACKPermitted bool
-}
-
-// SendPacketWithTS embeds the provided tsVal in the Timestamp option
-// for the packet to be sent out.
-func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) {
- r.TSVal = tsVal
- tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
- header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:])
- r.SendPacket(payload, tsOpt[:])
-}
-
-// SendPacket is a small wrapper function to build and send packets.
-func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
- packetHeaders := &Headers{
- SrcPort: r.SrcPort,
- DstPort: r.DstPort,
- Flags: r.Flags,
- SeqNum: r.NextSeqNum,
- AckNum: r.AckNum,
- RcvWnd: r.WndSize,
- TCPOpts: opts,
- }
- r.C.SendPacket(payload, packetHeaders)
- r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
-}
-
-// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches
-// the provided tsVal as well as returns the original packet.
-func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte {
- r.C.t.Helper()
- // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4]
- ackPacket := r.C.GetPacket()
- checker.IPv4(r.C.t, ackPacket,
- checker.TCP(
- checker.DstPort(r.SrcPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(r.AckNum)),
- checker.TCPAckNum(uint32(r.NextSeqNum)),
- checker.TCPTimestampChecker(true, 0, tsVal),
- ),
- )
- // Store the parsed TSVal from the ack as recentTS.
- tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
- opts := tcpSeg.ParsedOptions()
- r.RecentTS = opts.TSVal
- return ackPacket
-}
-
-// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
-// tsVal.
-func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
- r.C.t.Helper()
- _ = r.VerifyAndReturnACKWithTS(tsVal)
-}
-
-// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK
-// matches the provided rcvWnd.
-func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) {
- r.C.t.Helper()
- ackPacket := r.C.GetPacket()
- checker.IPv4(r.C.t, ackPacket,
- checker.TCP(
- checker.DstPort(r.SrcPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(r.AckNum)),
- checker.TCPAckNum(uint32(r.NextSeqNum)),
- checker.TCPWindow(rcvWnd),
- ),
- )
-}
-
-// VerifyACKNoSACK verifies that the ACK does not contain a SACK block.
-func (r *RawEndpoint) VerifyACKNoSACK() {
- r.VerifyACKHasSACK(nil)
-}
-
-// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks.
-func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
- // Read ACK and verify that the TCP options in the segment do
- // not contain a SACK block.
- ackPacket := r.C.GetPacket()
- checker.IPv4(r.C.t, ackPacket,
- checker.TCP(
- checker.DstPort(r.SrcPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(r.AckNum)),
- checker.TCPAckNum(uint32(r.NextSeqNum)),
- checker.TCPSACKBlockChecker(sackBlocks),
- ),
- )
-}
-
-// CreateConnectedWithOptionsNoDelay just calls CreateConnectedWithOptions
-// without delay.
-func (c *Context) CreateConnectedWithOptionsNoDelay(wantOptions header.TCPSynOptions) *RawEndpoint {
- return c.CreateConnectedWithOptions(wantOptions, 0 /* delay */)
-}
-
-// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
-// options enabled and returns a RawEndpoint which represents the other end of
-// the connection. It delays before a SYNACK is sent. This makes c.EP have a
-// higher RTT estimate so that spurious TLPs aren't sent in tests, which helps
-// reduce flakiness.
-//
-// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
-// does not carry an option that was not requested.
-func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
- var err tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // Start connection attempt.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&waitEntry)
-
- testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort}
- err = c.EP.Connect(testFullAddr)
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err)
- }
- // Receive SYN packet.
- b := c.GetPacket()
- // Validate that the syn has the timestamp option and a valid
- // TS value.
- mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
-
- synChecker := checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{
- MSS: mss,
- TS: true,
- WS: int(c.WindowScale),
- SACKPermitted: c.SACKEnabled(),
- }),
- )
- checker.IPv4(c.t, b, synChecker)
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- tcpSeg := header.TCP(header.IPv4(b).Payload())
- synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
-
- // Build options w/ tsVal to be sent in the SYN-ACK.
- synAckOptions := make([]byte, header.TCPOptionsMaximumSize)
- offset := 0
- if wantOptions.WS != -1 {
- offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:])
- }
- if wantOptions.TS {
- offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:])
- }
- if wantOptions.SACKPermitted {
- offset += header.EncodeSACKPermittedOption(synAckOptions[offset:])
- }
-
- offset += header.AddTCPOptionPadding(synAckOptions, offset)
-
- // Build SYN-ACK.
- c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
- iss := seqnum.Value(TestInitialSequenceNumber)
- if delay > 0 {
- // Sleep so that RTT is increased.
- time.Sleep(delay)
- }
- c.SendPacket(nil, &Headers{
- SrcPort: tcpSeg.DestinationPort(),
- DstPort: tcpSeg.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- TCPOpts: synAckOptions[:offset],
- })
-
- // Read ACK.
- var ackPacket []byte
- // Ignore retransimitted SYN packets.
- for {
- packet := c.GetPacket()
- if header.TCP(header.IPv4(packet).Payload()).Flags()&header.TCPFlagSyn != 0 {
- checker.IPv4(c.t, packet, synChecker)
- } else {
- ackPacket = packet
- break
- }
- }
-
- // Verify TCP header fields.
- tcpCheckers := []checker.TransportChecker{
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(c.IRS) + 1),
- checker.TCPAckNum(uint32(iss) + 1),
- }
-
- // Verify that tsEcr of ACK packet is wantOptions.TSVal if the
- // timestamp option was enabled, if not then we verify that
- // there is no timestamp in the ACK packet.
- if wantOptions.TS {
- tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal))
- } else {
- tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
- }
-
- checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...))
-
- ackSeg := header.TCP(header.IPv4(ackPacket).Payload())
- ackOptions := ackSeg.ParsedOptions()
-
- // Wait for connection to be established.
- select {
- case <-notifyCh:
- if err := c.EP.LastError(); err != nil {
- c.t.Fatalf("Unexpected error when connecting: %v", err)
- }
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for connection")
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // Store the source port in use by the endpoint.
- c.Port = tcpSeg.SourcePort()
-
- // Mark in context that timestamp option is enabled for this endpoint.
- c.TimeStampEnabled = true
- c.RcvdWindowScale = uint8(synOptions.WS)
- return &RawEndpoint{
- C: c,
- SrcPort: tcpSeg.DestinationPort(),
- DstPort: tcpSeg.SourcePort(),
- Flags: header.TCPFlagAck | header.TCPFlagPsh,
- NextSeqNum: iss + 1,
- AckNum: c.IRS.Add(1),
- WndSize: 30000,
- RecentTS: ackOptions.TSVal,
- TSVal: wantOptions.TSVal,
- SACKPermitted: wantOptions.SACKPermitted,
- }
-}
-
-// AcceptWithOptionsNoDelay delegates call to AcceptWithOptions without delay.
-func (c *Context) AcceptWithOptionsNoDelay(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
- return c.AcceptWithOptions(wndScale, synOptions, 0 /* delay */)
-}
-
-// AcceptWithOptions initializes a listening endpoint and connects to it with
-// the provided options enabled. It delays before the final ACK of the 3WHS is
-// sent. It also verifies that the SYN-ACK has the expected values for the
-// provided options.
-//
-// The function returns a RawEndpoint representing the other end of the accepted
-// endpoint.
-func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- if err := ep.Listen(10); err != nil {
- c.t.Fatalf("Listen failed: %v", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- rep := c.PassiveConnectWithOptions(100, wndScale, synOptions, delay)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- c.t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for accept")
- }
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
- c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- return rep
-}
-
-// PassiveConnect just disables WindowScaling and delegates the call to
-// PassiveConnectWithOptions.
-func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
- synOptions.WS = -1
- c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions, 0 /* delay */)
-}
-
-// PassiveConnectWithOptions initiates a new connection (with the specified TCP
-// options enabled) to the port on which the Context.ep is listening for new
-// connections. It also validates that the SYN-ACK has the expected values for
-// the enabled options. The final ACK of the handshake is delayed by specified
-// duration.
-//
-// NOTE: MSS is not a negotiated option and it can be asymmetric
-// in each direction. This function uses the maxPayload to set the MSS to be
-// sent to the peer on a connect and validates that the MSS in the SYN-ACK
-// response is equal to the MTU - (tcphdr len + iphdr len).
-//
-// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
-// value of the window scaling option to be sent in the SYN. If synOptions.WS >
-// 0 then we send the WindowScale option.
-func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
- c.t.Helper()
- opts := make([]byte, header.TCPOptionsMaximumSize)
- offset := 0
- offset += header.EncodeMSSOption(uint32(maxPayload), opts)
-
- if synOptions.WS >= 0 {
- offset += header.EncodeWSOption(3, opts[offset:])
- }
- if synOptions.TS {
- offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:])
- }
-
- if synOptions.SACKPermitted {
- offset += header.EncodeSACKPermittedOption(opts[offset:])
- }
-
- paddingToAdd := 4 - offset%4
- // Now add any padding bytes that might be required to quad align the
- // options.
- for i := offset; i < offset+paddingToAdd; i++ {
- opts[i] = header.TCPOptionNOP
- }
- offset += paddingToAdd
-
- // Send a SYN request.
- iss := seqnum.Value(TestInitialSequenceNumber)
- c.SendPacket(nil, &Headers{
- SrcPort: TestPort,
- DstPort: StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- TCPOpts: opts[:offset],
- })
-
- // Receive the SYN-ACK reply. Make sure MSS and other expected options
- // are present.
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */)
- c.IRS = seqnum.Value(tcp.SequenceNumber())
-
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(StackPort),
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.TCPAckNum(uint32(iss) + 1),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
- }
-
- // If TS option was enabled in the original SYN then add a checker to
- // validate the Timestamp option in the SYN-ACK.
- if synOptions.TS {
- tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal))
- } else {
- tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
- }
-
- checker.IPv4(c.t, b, checker.TCP(tcpCheckers...))
- rcvWnd := seqnum.Size(30000)
- ackHeaders := &Headers{
- SrcPort: TestPort,
- DstPort: StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- RcvWnd: rcvWnd,
- }
-
- // If WS was expected to be in effect then scale the advertised window
- // correspondingly.
- if synOptions.WS > 0 {
- ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS)
- }
-
- parsedOpts := tcp.ParsedOptions()
- if synOptions.TS {
- // Echo the tsVal back to the peer in the tsEcr field of the
- // timestamp option.
- // Increment TSVal by 1 from the value sent in the SYN and echo
- // the TSVal in the SYN-ACK in the TSEcr field.
- opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
- header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:])
- ackHeaders.TCPOpts = opts[:]
- }
-
- // Send ACK, delay if needed.
- if delay > 0 {
- time.Sleep(delay)
- }
- c.SendPacket(nil, ackHeaders)
-
- c.RcvdWindowScale = uint8(rcvdSynOptions.WS)
- c.Port = StackPort
-
- return &RawEndpoint{
- C: c,
- SrcPort: TestPort,
- DstPort: StackPort,
- Flags: header.TCPFlagPsh | header.TCPFlagAck,
- NextSeqNum: iss + 1,
- AckNum: c.IRS + 1,
- WndSize: rcvWnd,
- SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(),
- RecentTS: parsedOpts.TSVal,
- TSVal: synOptions.TSVal + 1,
- }
-}
-
-// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
-// for the Stack in the context.
-func (c *Context) SACKEnabled() bool {
- var v tcpip.TCPSACKEnabled
- if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
- // Stack doesn't support SACK. So just return.
- return false
- }
- return bool(v)
-}
-
-// SetGSOEnabled enables or disables generic segmentation offload.
-func (c *Context) SetGSOEnabled(enable bool) {
- if enable {
- c.linkEP.SupportedGSOKind = stack.HWGSOSupported
- } else {
- c.linkEP.SupportedGSOKind = stack.GSONotSupported
- }
-}
-
-// MSSWithoutOptions returns the value for the MSS used by the stack when no
-// options are in use.
-func (c *Context) MSSWithoutOptions() uint16 {
- return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
-}
-
-// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no
-// options are in use for IPv6 packets.
-func (c *Context) MSSWithoutOptionsV6() uint16 {
- return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize)
-}
diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go
deleted file mode 100644
index 479752de7..000000000
--- a/pkg/tcpip/transport/tcp/timer_test.go
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp
-
-import (
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/sleep"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
-)
-
-func TestCleanup(t *testing.T) {
- const (
- timerDurationSeconds = 2
- isAssertedTimeoutSeconds = timerDurationSeconds + 1
- )
-
- clock := faketime.NewManualClock()
-
- tmr := timer{}
- w := sleep.Waker{}
- tmr.init(clock, &w)
- tmr.enable(timerDurationSeconds * time.Second)
- tmr.cleanup()
-
- if want := (timer{}); tmr != want {
- t.Errorf("got tmr = %+v, want = %+v", tmr, want)
- }
-
- // The waker should not be asserted.
- for i := 0; i < isAssertedTimeoutSeconds; i++ {
- clock.Advance(time.Second)
- if w.IsAsserted() {
- t.Fatalf("waker asserted unexpectedly")
- }
- }
-}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
deleted file mode 100644
index 3ad6994a7..000000000
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ /dev/null
@@ -1,23 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "tcpconntrack",
- srcs = ["tcp_conntrack.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- ],
-)
-
-go_test(
- name = "tcpconntrack_test",
- size = "small",
- srcs = ["tcp_conntrack_test.go"],
- deps = [
- ":tcpconntrack",
- "//pkg/tcpip/header",
- ],
-)
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
deleted file mode 100644
index 6c5ddc3c7..000000000
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
+++ /dev/null
@@ -1,511 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcpconntrack_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
-)
-
-// connected creates a connection tracker TCB and sets it to a connected state
-// by performing a 3-way handshake.
-func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: iss,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: irw,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive SYN-ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: irs,
- AckNum: iss + 1,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- WindowSize: isw,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: iss + 1,
- AckNum: irs + 1,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: irw,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- return &tcb
-}
-
-func TestConnectionRefused(t *testing.T) {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive RST.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 1235,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagRst | header.TCPFlagAck,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
- }
-}
-
-func TestConnectionRefusedInSynRcvd(t *testing.T) {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive SYN.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive RST with no ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 790,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagRst,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
- }
-}
-
-func TestConnectionResetInSynRcvd(t *testing.T) {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive SYN.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send RST with no ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1235,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagRst,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
- }
-}
-
-func TestRetransmitOnSynSent(t *testing.T) {
- // Send initial SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Retransmit the same SYN.
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting)
- }
-}
-
-func TestRetransmitOnSynRcvd(t *testing.T) {
- // Send initial SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive SYN. This will cause the state to go to SYN-RCVD.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Retransmit the original SYN.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Transmit a SYN-ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 790,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-}
-
-func TestClosedBySelf(t *testing.T) {
- tcb := connected(t, 1234, 789, 30000, 50000)
-
- // Send FIN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1235,
- AckNum: 790,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive FIN/ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 790,
- AckNum: 1236,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1236,
- AckNum: 791,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
- }
-}
-
-func TestClosedByPeer(t *testing.T) {
- tcb := connected(t, 1234, 789, 30000, 50000)
-
- // Receive FIN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 790,
- AckNum: 1235,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send FIN/ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1235,
- AckNum: 791,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 791,
- AckNum: 1236,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer)
- }
-}
-
-func TestSendAndReceiveDataClosedBySelf(t *testing.T) {
- sseq := uint32(1234)
- rseq := uint32(789)
- tcb := connected(t, sseq, rseq, 30000, 50000)
- sseq++
- rseq++
-
- // Send some data.
- tcp := make(header.TCP, header.TCPMinimumSize+1024)
-
- for i := uint32(0); i < 10; i++ {
- // Send some data.
- tcp.Encode(&header.TCPFields{
- SeqNum: sseq,
- AckNum: rseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
- sseq += uint32(len(tcp)) - header.TCPMinimumSize
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive ack for data.
- tcp.Encode(&header.TCPFields{
- SeqNum: rseq,
- AckNum: sseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
- }
-
- for i := uint32(0); i < 10; i++ {
- // Receive some data.
- tcp.Encode(&header.TCPFields{
- SeqNum: rseq,
- AckNum: sseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 50000,
- })
- rseq += uint32(len(tcp)) - header.TCPMinimumSize
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ack for data.
- tcp.Encode(&header.TCPFields{
- SeqNum: sseq,
- AckNum: rseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
- }
-
- // Send FIN.
- tcp = tcp[:header.TCPMinimumSize]
- tcp.Encode(&header.TCPFields{
- SeqNum: sseq,
- AckNum: rseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 30000,
- })
- sseq++
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive FIN/ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: rseq,
- AckNum: sseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 50000,
- })
- rseq++
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: sseq,
- AckNum: rseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
- }
-}
-
-func TestIgnoreBadResetOnSynSent(t *testing.T) {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive a RST with a bad ACK, it should not cause the connection to
- // be reset.
- acks := []uint32{1234, 1236, 1000, 5000}
- flags := []header.TCPFlags{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck}
- for _, a := range acks {
- for _, f := range flags {
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: a,
- DataOffset: header.TCPMinimumSize,
- Flags: f,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
- }
- }
-
- // Complete the handshake.
- // Receive SYN-ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 1235,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1235,
- AckNum: 790,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-}
diff --git a/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go
new file mode 100644
index 000000000..ff53204da
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package tcpconntrack
diff --git a/pkg/tcpip/transport/transport_state_autogen.go b/pkg/tcpip/transport/transport_state_autogen.go
new file mode 100644
index 000000000..c023165ec
--- /dev/null
+++ b/pkg/tcpip/transport/transport_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package transport
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
deleted file mode 100644
index d2c0963b0..000000000
--- a/pkg/tcpip/transport/udp/BUILD
+++ /dev/null
@@ -1,68 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "udp_packet_list",
- out = "udp_packet_list.go",
- package = "udp",
- prefix = "udpPacket",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*udpPacket",
- "Linker": "*udpPacket",
- },
-)
-
-go_library(
- name = "udp",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "forwarder.go",
- "protocol.go",
- "udp_packet_list.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport",
- "//pkg/tcpip/transport/internal/network",
- "//pkg/tcpip/transport/raw",
- "//pkg/waiter",
- ],
-)
-
-go_test(
- name = "udp_x_test",
- size = "small",
- srcs = ["udp_test.go"],
- deps = [
- ":udp",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- "@org_golang_x_time//rate:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go
new file mode 100644
index 000000000..c396f77c9
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_packet_list.go
@@ -0,0 +1,221 @@
+package udp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type udpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type udpPacketList struct {
+ head *udpPacket
+ tail *udpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *udpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *udpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *udpPacketList) Front() *udpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *udpPacketList) Back() *udpPacket {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *udpPacketList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (udpPacketElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *udpPacketList) PushFront(e *udpPacket) {
+ linker := udpPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *udpPacketList) PushBack(e *udpPacket) {
+ linker := udpPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *udpPacketList) PushBackList(m *udpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *udpPacketList) InsertAfter(b, e *udpPacket) {
+ bLinker := udpPacketElementMapper{}.linkerFor(b)
+ eLinker := udpPacketElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *udpPacketList) InsertBefore(a, e *udpPacket) {
+ aLinker := udpPacketElementMapper{}.linkerFor(a)
+ eLinker := udpPacketElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ udpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *udpPacketList) Remove(e *udpPacket) {
+ linker := udpPacketElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ udpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ udpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type udpPacketEntry struct {
+ next *udpPacket
+ prev *udpPacket
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *udpPacketEntry) Next() *udpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *udpPacketEntry) Prev() *udpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *udpPacketEntry) SetNext(elem *udpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *udpPacketEntry) SetPrev(elem *udpPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
new file mode 100644
index 000000000..3a661e327
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -0,0 +1,197 @@
+// automatically generated by stateify.
+
+package udp
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (p *udpPacket) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.udpPacket"
+}
+
+func (p *udpPacket) StateFields() []string {
+ return []string{
+ "udpPacketEntry",
+ "netProto",
+ "senderAddress",
+ "destinationAddress",
+ "packetInfo",
+ "data",
+ "receivedAt",
+ "tos",
+ }
+}
+
+func (p *udpPacket) beforeSave() {}
+
+// +checklocksignore
+func (p *udpPacket) StateSave(stateSinkObject state.Sink) {
+ p.beforeSave()
+ var dataValue buffer.VectorisedView
+ dataValue = p.saveData()
+ stateSinkObject.SaveValue(5, dataValue)
+ var receivedAtValue int64
+ receivedAtValue = p.saveReceivedAt()
+ stateSinkObject.SaveValue(6, receivedAtValue)
+ stateSinkObject.Save(0, &p.udpPacketEntry)
+ stateSinkObject.Save(1, &p.netProto)
+ stateSinkObject.Save(2, &p.senderAddress)
+ stateSinkObject.Save(3, &p.destinationAddress)
+ stateSinkObject.Save(4, &p.packetInfo)
+ stateSinkObject.Save(7, &p.tos)
+}
+
+func (p *udpPacket) afterLoad() {}
+
+// +checklocksignore
+func (p *udpPacket) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &p.udpPacketEntry)
+ stateSourceObject.Load(1, &p.netProto)
+ stateSourceObject.Load(2, &p.senderAddress)
+ stateSourceObject.Load(3, &p.destinationAddress)
+ stateSourceObject.Load(4, &p.packetInfo)
+ stateSourceObject.Load(7, &p.tos)
+ stateSourceObject.LoadValue(5, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) })
+ stateSourceObject.LoadValue(6, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) })
+}
+
+func (e *endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.endpoint"
+}
+
+func (e *endpoint) StateFields() []string {
+ return []string{
+ "DefaultSocketOptionsHandler",
+ "waiterQueue",
+ "uniqueID",
+ "net",
+ "stats",
+ "ops",
+ "rcvReady",
+ "rcvList",
+ "rcvBufSize",
+ "rcvClosed",
+ "lastError",
+ "portFlags",
+ "boundBindToDevice",
+ "boundPortFlags",
+ "readShutdown",
+ "effectiveNetProtos",
+ "frozen",
+ "localPort",
+ "remotePort",
+ }
+}
+
+// +checklocksignore
+func (e *endpoint) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.DefaultSocketOptionsHandler)
+ stateSinkObject.Save(1, &e.waiterQueue)
+ stateSinkObject.Save(2, &e.uniqueID)
+ stateSinkObject.Save(3, &e.net)
+ stateSinkObject.Save(4, &e.stats)
+ stateSinkObject.Save(5, &e.ops)
+ stateSinkObject.Save(6, &e.rcvReady)
+ stateSinkObject.Save(7, &e.rcvList)
+ stateSinkObject.Save(8, &e.rcvBufSize)
+ stateSinkObject.Save(9, &e.rcvClosed)
+ stateSinkObject.Save(10, &e.lastError)
+ stateSinkObject.Save(11, &e.portFlags)
+ stateSinkObject.Save(12, &e.boundBindToDevice)
+ stateSinkObject.Save(13, &e.boundPortFlags)
+ stateSinkObject.Save(14, &e.readShutdown)
+ stateSinkObject.Save(15, &e.effectiveNetProtos)
+ stateSinkObject.Save(16, &e.frozen)
+ stateSinkObject.Save(17, &e.localPort)
+ stateSinkObject.Save(18, &e.remotePort)
+}
+
+// +checklocksignore
+func (e *endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.DefaultSocketOptionsHandler)
+ stateSourceObject.Load(1, &e.waiterQueue)
+ stateSourceObject.Load(2, &e.uniqueID)
+ stateSourceObject.Load(3, &e.net)
+ stateSourceObject.Load(4, &e.stats)
+ stateSourceObject.Load(5, &e.ops)
+ stateSourceObject.Load(6, &e.rcvReady)
+ stateSourceObject.Load(7, &e.rcvList)
+ stateSourceObject.Load(8, &e.rcvBufSize)
+ stateSourceObject.Load(9, &e.rcvClosed)
+ stateSourceObject.Load(10, &e.lastError)
+ stateSourceObject.Load(11, &e.portFlags)
+ stateSourceObject.Load(12, &e.boundBindToDevice)
+ stateSourceObject.Load(13, &e.boundPortFlags)
+ stateSourceObject.Load(14, &e.readShutdown)
+ stateSourceObject.Load(15, &e.effectiveNetProtos)
+ stateSourceObject.Load(16, &e.frozen)
+ stateSourceObject.Load(17, &e.localPort)
+ stateSourceObject.Load(18, &e.remotePort)
+ stateSourceObject.AfterLoad(e.afterLoad)
+}
+
+func (l *udpPacketList) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.udpPacketList"
+}
+
+func (l *udpPacketList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *udpPacketList) beforeSave() {}
+
+// +checklocksignore
+func (l *udpPacketList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *udpPacketList) afterLoad() {}
+
+// +checklocksignore
+func (l *udpPacketList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *udpPacketEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.udpPacketEntry"
+}
+
+func (e *udpPacketEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *udpPacketEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *udpPacketEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *udpPacketEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *udpPacketEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*udpPacket)(nil))
+ state.Register((*endpoint)(nil))
+ state.Register((*udpPacketList)(nil))
+ state.Register((*udpPacketEntry)(nil))
+}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
deleted file mode 100644
index b3199489c..000000000
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ /dev/null
@@ -1,2602 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package udp_test
-
-import (
- "bytes"
- "fmt"
- "io/ioutil"
- "math/rand"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "golang.org/x/time/rate"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// Addresses and ports used for testing. It is recommended that tests stick to
-// using these addresses as it allows using the testFlow helper.
-// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*'
-// represents the remote endpoint.
-const (
- v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
- stackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackV4MappedAddr = v4MappedAddrPrefix + stackAddr
- testV4MappedAddr = v4MappedAddrPrefix + testAddr
- multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
- broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr
- v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00"
-
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testAddr = "\x0a\x00\x00\x02"
- testPort = 4096
- invalidPort = 8192
- multicastAddr = "\xe8\x2b\xd3\xea"
- multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- broadcastAddr = header.IPv4Broadcast
- testTOS = 0x80
-
- // defaultMTU is the MTU, in bytes, used throughout the tests, except
- // where another value is explicitly used. It is chosen to match the MTU
- // of loopback interfaces on linux systems.
- defaultMTU = 65536
-)
-
-// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in
-// a packet header. These values are used to populate a header or verify one.
-// Note that because they are used in packet headers, the addresses are never in
-// a V4-mapped format.
-type header4Tuple struct {
- srcAddr tcpip.FullAddress
- dstAddr tcpip.FullAddress
-}
-
-// testFlow implements a helper type used for sending and receiving test
-// packets. A given test flow value defines 1) the socket endpoint used for the
-// test and 2) the type of packet send or received on the endpoint. E.g., a
-// multicastV6Only flow is a V6 multicast packet passing through a V6-only
-// endpoint. The type provides helper methods to characterize the flow (e.g.,
-// isV4) as well as return a proper header4Tuple for it.
-type testFlow int
-
-const (
- unicastV4 testFlow = iota // V4 unicast on a V4 socket
- unicastV4in6 // V4-mapped unicast on a V6-dual socket
- unicastV6 // V6 unicast on a V6 socket
- unicastV6Only // V6 unicast on a V6-only socket
- multicastV4 // V4 multicast on a V4 socket
- multicastV4in6 // V4-mapped multicast on a V6-dual socket
- multicastV6 // V6 multicast on a V6 socket
- multicastV6Only // V6 multicast on a V6-only socket
- broadcast // V4 broadcast on a V4 socket
- broadcastIn6 // V4-mapped broadcast on a V6-dual socket
- reverseMulticast4 // V4 multicast src. Must fail.
- reverseMulticast6 // V6 multicast src. Must fail.
-)
-
-func (flow testFlow) String() string {
- switch flow {
- case unicastV4:
- return "unicastV4"
- case unicastV6:
- return "unicastV6"
- case unicastV6Only:
- return "unicastV6Only"
- case unicastV4in6:
- return "unicastV4in6"
- case multicastV4:
- return "multicastV4"
- case multicastV6:
- return "multicastV6"
- case multicastV6Only:
- return "multicastV6Only"
- case multicastV4in6:
- return "multicastV4in6"
- case broadcast:
- return "broadcast"
- case broadcastIn6:
- return "broadcastIn6"
- case reverseMulticast4:
- return "reverseMulticast4"
- case reverseMulticast6:
- return "reverseMulticast6"
- default:
- return "unknown"
- }
-}
-
-// packetDirection explains if a flow is incoming (read) or outgoing (write).
-type packetDirection int
-
-const (
- incoming packetDirection = iota
- outgoing
-)
-
-// header4Tuple returns the header4Tuple for the given flow and direction. Note
-// that the tuple contains no mapped addresses as those only exist at the socket
-// level but not at the packet header level.
-func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
- var h header4Tuple
- if flow.isV4() {
- if d == outgoing {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
- dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
- }
- } else {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
- dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
- }
- }
- if flow.isMulticast() {
- h.dstAddr.Addr = multicastAddr
- } else if flow.isBroadcast() {
- h.dstAddr.Addr = broadcastAddr
- }
- } else { // IPv6
- if d == outgoing {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- }
- } else {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- }
- }
- if flow.isMulticast() {
- h.dstAddr.Addr = multicastV6Addr
- }
- }
- if flow.isReverseMulticast() {
- h.srcAddr.Addr = flow.getMcastAddr()
- }
- return h
-}
-
-func (flow testFlow) getMcastAddr() tcpip.Address {
- if flow.isV4() {
- return multicastAddr
- }
- return multicastV6Addr
-}
-
-// mapAddrIfApplicable converts the given V4 address into its V4-mapped version
-// if it is applicable to the flow.
-func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address {
- if flow.isMapped() {
- return v4MappedAddrPrefix + v4Addr
- }
- return v4Addr
-}
-
-// netProto returns the protocol number used for the network packet.
-func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
- if flow.isV4() {
- return ipv4.ProtocolNumber
- }
- return ipv6.ProtocolNumber
-}
-
-// sockProto returns the protocol number used when creating the socket
-// endpoint for this flow.
-func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
- switch flow {
- case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6:
- return ipv6.ProtocolNumber
- case unicastV4, multicastV4, broadcast, reverseMulticast4:
- return ipv4.ProtocolNumber
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) {
- if flow.isV4() {
- return checker.IPv4
- }
- return checker.IPv6
-}
-
-func (flow testFlow) isV6() bool { return !flow.isV4() }
-func (flow testFlow) isV4() bool {
- return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped()
-}
-
-func (flow testFlow) isV6Only() bool {
- switch flow {
- case unicastV6Only, multicastV6Only:
- return true
- case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isMulticast() bool {
- switch flow {
- case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
- return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isBroadcast() bool {
- switch flow {
- case broadcast, broadcastIn6:
- return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isMapped() bool {
- switch flow {
- case unicastV4in6, multicastV4in6, broadcastIn6:
- return true
- case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isReverseMulticast() bool {
- switch flow {
- case reverseMulticast4, reverseMulticast6:
- return true
- default:
- return false
- }
-}
-
-type testContext struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
- nicID tcpip.NICID
-
- ep tcpip.Endpoint
- wq waiter.Queue
-}
-
-func newDualTestContext(t *testing.T, mtu uint32) *testContext {
- t.Helper()
- return newDualTestContextWithHandleLocal(t, mtu, true)
-}
-
-func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext {
- const nicID = 1
-
- t.Helper()
-
- options := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
- HandleLocal: handleLocal,
- Clock: &faketime.NullClock{},
- }
- s := stack.New(options)
- // Disable ICMP rate limiter because we're using Null clock, which never advances time and thus
- // never allows ICMP messages.
- s.SetICMPLimit(rate.Inf)
- ep := channel.New(256, mtu, "")
- wep := stack.LinkEndpoint(ep)
-
- if testing.Verbose() {
- wep = sniffer.New(ep)
- }
- if err := s.CreateNIC(nicID, wep); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- protocolAddrV4 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(stackAddr).WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
- }
-
- protocolAddrV6 := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(stackV6Addr).WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
-
- return &testContext{
- t: t,
- s: s,
- nicID: nicID,
- linkEP: ep,
- }
-}
-
-func (c *testContext) cleanup() {
- if c.ep != nil {
- c.ep.Close()
- }
-}
-
-func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
- c.t.Helper()
-
- var err tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
- if err != nil {
- c.t.Fatal("NewEndpoint failed: ", err)
- }
-}
-
-func (c *testContext) createEndpointForFlow(flow testFlow) {
- c.t.Helper()
-
- c.createEndpoint(flow.sockProto())
- if flow.isV6Only() {
- c.ep.SocketOptions().SetV6Only(true)
- } else if flow.isBroadcast() {
- c.ep.SocketOptions().SetBroadcast(true)
- }
-}
-
-// getPacketAndVerify reads a packet from the link endpoint and verifies the
-// header against expected values from the given test flow. In addition, it
-// calls any extra checker functions provided.
-func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
- c.t.Helper()
-
- p, ok := c.linkEP.Read()
- if !ok {
- c.t.Fatalf("Packet wasn't written out")
- return nil
- }
-
- if p.Proto != flow.netProto() {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
- }
-
- if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want {
- c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want)
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- h := flow.header4Tuple(outgoing)
- checkers = append(
- checkers,
- checker.SrcAddr(h.srcAddr.Addr),
- checker.DstAddr(h.dstAddr.Addr),
- checker.UDP(checker.DstPort(h.dstAddr.Port)),
- )
- flow.checkerFn()(c.t, b, checkers...)
- return b
-}
-
-// injectPacket creates a packet of the given flow and with the given payload,
-// and injects it into the link endpoint. If badChecksum is true, the packet has
-// a bad checksum in the UDP header.
-func (c *testContext) injectPacket(flow testFlow, payload []byte, badChecksum bool) {
- c.t.Helper()
-
- h := flow.header4Tuple(incoming)
- if flow.isV4() {
- buf := c.buildV4Packet(payload, &h)
- if badChecksum {
- // Invalidate the UDP header checksum field, taking care to avoid
- // overflow to zero, which would disable checksum validation.
- for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
- u.SetChecksum(u.Checksum() + 1)
- if u.Checksum() != 0 {
- break
- }
- }
- }
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- } else {
- buf := c.buildV6Packet(payload, &h)
- if badChecksum {
- // Invalidate the UDP header checksum field (Unlike IPv4, zero is
- // a valid checksum value for IPv6 so no need to avoid it).
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.SetChecksum(u.Checksum() + 1)
- }
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-}
-
-// buildV6Packet creates a V6 test packet with the given payload and header
-// values in a buffer.
-func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
- payloadStart := len(buf) - len(payload)
- copy(buf[payloadStart:], payload)
-
- // Initialize the IP header.
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
- })
-
- // Initialize the UDP header.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.Encode(&header.UDPFields{
- SrcPort: h.srcAddr.Port,
- DstPort: h.dstAddr.Port,
- Length: uint16(header.UDPMinimumSize + len(payload)),
- })
-
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
-
- // Calculate the UDP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum))
-
- return buf
-}
-
-// buildV4Packet creates a V4 test packet with the given payload and header
-// values in a buffer.
-func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
- payloadStart := len(buf) - len(payload)
- copy(buf[payloadStart:], payload)
-
- // Initialize the IP header.
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- TOS: testTOS,
- TotalLength: uint16(len(buf)),
- TTL: 65,
- Protocol: uint8(udp.ProtocolNumber),
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Initialize the UDP header.
- u := header.UDP(buf[header.IPv4MinimumSize:])
- u.Encode(&header.UDPFields{
- SrcPort: h.srcAddr.Port,
- DstPort: h.dstAddr.Port,
- Length: uint16(header.UDPMinimumSize + len(payload)),
- })
-
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
-
- // Calculate the UDP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum))
-
- return buf
-}
-
-func newPayload() []byte {
- return newMinPayload(30)
-}
-
-func newMinPayload(minSize int) []byte {
- b := make([]byte, minSize+rand.Intn(100))
- for i := range b {
- b[i] = byte(rand.Intn(256))
- }
- return b
-}
-
-func TestBindToDeviceOption(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- Clock: &faketime.NullClock{},
- })
-
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- defer ep.Close()
-
- opts := stack.NICOptions{Name: "my_device"}
- if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
- t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err)
- }
-
- // nicIDPtr is used instead of taking the address of NICID literals, which is
- // a compiler error.
- nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
- return &s
- }
-
- testActions := []struct {
- name string
- setBindToDevice *tcpip.NICID
- setBindToDeviceError tcpip.Error
- getBindToDevice int32
- }{
- {"GetDefaultValue", nil, nil, 0},
- {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0},
- {"BindToExistent", nicIDPtr(321), nil, 321},
- {"UnbindToDevice", nicIDPtr(0), nil, 0},
- }
- for _, testAction := range testActions {
- t.Run(testAction.name, func(t *testing.T) {
- if testAction.setBindToDevice != nil {
- bindToDevice := int32(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
- }
- }
- bindToDevice := ep.SocketOptions().GetBindToDevice()
- if bindToDevice != testAction.getBindToDevice {
- t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice)
- }
- })
- }
-}
-
-// testReadInternal sends a packet of the given test flow into the stack by
-// injecting it into the link endpoint. It then attempts to read it from the
-// UDP endpoint and depending on if this was expected to succeed verifies its
-// correctness including any additional checker functions provided.
-func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
- c.t.Helper()
-
- payload := newPayload()
- c.injectPacket(flow, payload, false)
-
- // Try to receive the data.
- we, ch := waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&we, waiter.ReadableEvents)
- defer c.wq.EventUnregister(&we)
-
- // Take a snapshot of the stats to validate them at the end of the test.
- epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
-
- var buf bytes.Buffer
- res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- // Wait for data to become available.
- select {
- case <-ch:
- res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
-
- default:
- if packetShouldBeDropped {
- return // expected to time out
- }
- c.t.Fatal("timed out waiting for data")
- }
- }
-
- if expectReadError && err != nil {
- c.checkEndpointReadStats(1, epstats, err)
- return
- }
-
- if err != nil {
- c.t.Fatal("Read failed:", err)
- }
-
- if packetShouldBeDropped {
- c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr)
- }
-
- // Check the read result.
- h := flow.header4Tuple(incoming)
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr},
- }, res, checker.IgnoreCmpPath(
- "ControlMessages", // ControlMessages will be checked later.
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff)
- }
-
- // Check the payload.
- v := buf.Bytes()
- if !bytes.Equal(payload, v) {
- c.t.Fatalf("got payload = %x, want = %x", v, payload)
- }
-
- // Run any checkers against the ControlMessages.
- for _, f := range checkers {
- f(c.t, res.ControlMessages)
- }
-
- c.checkEndpointReadStats(1, epstats, err)
-}
-
-// testRead sends a packet of the given test flow into the stack by injecting it
-// into the link endpoint. It then reads it from the UDP endpoint and verifies
-// its correctness including any additional checker functions provided.
-func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
- c.t.Helper()
- testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
-}
-
-// testFailingRead sends a packet of the given test flow into the stack by
-// injecting it into the link endpoint. It then tries to read it from the UDP
-// endpoint and expects this to fail.
-func testFailingRead(c *testContext, flow testFlow, expectReadError bool) {
- c.t.Helper()
- testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError)
-}
-
-func TestBindEphemeralPort(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %s", err)
- }
-}
-
-func TestBindReservedPort(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- addr, err := c.ep.GetLocalAddress()
- if err != nil {
- t.Fatalf("GetLocalAddress failed: %s", err)
- }
-
- // We can't bind the address reserved by the connected endpoint above.
- {
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
- {
- err := ep.Bind(addr)
- if _, ok := err.(*tcpip.ErrPortInUse); !ok {
- t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
- }
- }
- }
-
- func() {
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
- // We can't bind ipv4-any on the port reserved by the connected endpoint
- // above, since the endpoint is dual-stack.
- {
- err := ep.Bind(tcpip.FullAddress{Port: addr.Port})
- if _, ok := err.(*tcpip.ErrPortInUse); !ok {
- t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
- }
- }
- // We can bind an ipv4 address on this port, though.
- if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %s", err)
- }
- }()
-
- // Once the connected endpoint releases its port reservation, we are able to
- // bind ipv4-any once again.
- c.ep.Close()
- func() {
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
- if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %s", err)
- }
- }()
-}
-
-func TestV4ReadOnV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4in6)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4in6)
-}
-
-func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4in6)
-
- // Bind to v4 mapped wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4in6)
-}
-
-func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4in6)
-
- // Bind to local address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4in6)
-}
-
-func TestV6ReadOnV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV6)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV6)
-}
-
-// TestV4ReadSelfSource checks that packets coming from a local IP address are
-// correctly dropped when handleLocal is true and not otherwise.
-func TestV4ReadSelfSource(t *testing.T) {
- for _, tt := range []struct {
- name string
- handleLocal bool
- wantErr tcpip.Error
- wantInvalidSource uint64
- }{
- {"HandleLocal", false, nil, 0},
- {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1},
- } {
- t.Run(tt.name, func(t *testing.T) {
- c := newDualTestContextWithHandleLocal(t, defaultMTU, tt.handleLocal)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4)
-
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV4.header4Tuple(incoming)
- h.srcAddr = h.dstAddr
-
- buf := c.buildV4Packet(payload, &h)
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
- t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
- }
-
- if _, err := c.ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tt.wantErr {
- t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr)
- }
- })
- }
-}
-
-func TestV4ReadOnV4(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4)
-}
-
-// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
-// address and receive data sent to that address.
-func TestReadOnBoundToMulticast(t *testing.T) {
- // FIXME(b/128189410): multicastV4in6 currently doesn't work as
- // AddMembershipOption doesn't handle V4in6 addresses.
- for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to multicast address.
- mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr())
- if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
-
- // Join multicast group.
- ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
- if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
- }
-
- // Check that we receive multicast packets but not unicast or broadcast
- // ones.
- testRead(c, flow)
- testFailingRead(c, broadcast, false /* expectReadError */)
- testFailingRead(c, unicastV4, false /* expectReadError */)
- })
- }
-}
-
-// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
-// address and can receive only broadcast data.
-func TestV4ReadOnBoundToBroadcast(t *testing.T) {
- for _, flow := range []testFlow{broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to broadcast address.
- bcastAddr := flow.mapAddrIfApplicable(broadcastAddr)
- if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Check that we receive broadcast packets but not unicast ones.
- testRead(c, flow)
- testFailingRead(c, unicastV4, false /* expectReadError */)
- })
- }
-}
-
-// TestReadFromMulticast checks that an endpoint will NOT receive a packet
-// that was sent with multicast SOURCE address.
-func TestReadFromMulticast(t *testing.T) {
- for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- testFailingRead(c, flow, false /* expectReadError */)
- })
- }
-}
-
-// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
-// and receive broadcast and unicast data.
-func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
- for _, flow := range []testFlow{broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s (", err)
- }
-
- // Check that we receive both broadcast and unicast packets.
- testRead(c, flow)
- testRead(c, unicastV4)
- })
- }
-}
-
-// testFailingWrite sends a packet of the given test flow into the UDP endpoint
-// and verifies it fails with the provided error code.
-func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) {
- c.t.Helper()
- // Take a snapshot of the stats to validate them at the end of the test.
- epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- h := flow.header4Tuple(outgoing)
- writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
-
- var r bytes.Reader
- r.Reset(newPayload())
- _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
- })
- c.checkEndpointWriteStats(1, epstats, gotErr)
- if gotErr != wantErr {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
- }
-}
-
-// testWrite sends a packet of the given test flow from the UDP endpoint to the
-// flow's destination address:port. It then receives it from the link endpoint
-// and verifies its correctness including any additional checker functions
-// provided.
-func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
- c.t.Helper()
- return testWriteAndVerifyInternal(c, flow, true, checkers...)
-}
-
-// testWriteWithoutDestination sends a packet of the given test flow from the
-// UDP endpoint without giving a destination address:port. It then receives it
-// from the link endpoint and verifies its correctness including any additional
-// checker functions provided.
-func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
- c.t.Helper()
- return testWriteAndVerifyInternal(c, flow, false, checkers...)
-}
-
-func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View {
- c.t.Helper()
- // Take a snapshot of the stats to validate them at the end of the test.
- epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
-
- writeOpts := tcpip.WriteOptions{}
- if setDest {
- h := flow.header4Tuple(outgoing)
- writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
- writeOpts = tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
- }
- }
- var r bytes.Reader
- payload := newPayload()
- r.Reset(payload)
- n, err := c.ep.Write(&r, writeOpts)
- if err != nil {
- c.t.Fatalf("Write failed: %s", err)
- }
- if n != int64(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
- }
- c.checkEndpointWriteStats(1, epstats, err)
- return payload
-}
-
-func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
- c.t.Helper()
- payload := testWriteNoVerify(c, flow, setDest)
- // Received the packet and check the payload.
- b := c.getPacketAndVerify(flow, checkers...)
- var udpH header.UDP
- if flow.isV4() {
- udpH = header.IPv4(b).Payload()
- } else {
- udpH = header.IPv6(b).Payload()
- }
- if !bytes.Equal(payload, udpH.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload)
- }
-
- return udpH.SourcePort()
-}
-
-func testDualWrite(c *testContext) uint16 {
- c.t.Helper()
-
- v4Port := testWrite(c, unicastV4in6)
- v6Port := testWrite(c, unicastV6)
- if v4Port != v6Port {
- c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
- }
-
- return v4Port
-}
-
-func TestDualWriteUnbound(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- testDualWrite(c)
-}
-
-func TestDualWriteBoundToWildcard(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- p := testDualWrite(c)
- if p != stackPort {
- c.t.Fatalf("Bad port: got %v, want %v", p, stackPort)
- }
-}
-
-func TestDualWriteConnectedToV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v6 address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, unicastV6)
-
- // Write to V4 mapped address.
- testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{})
- const want = 1
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
- c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
- }
-}
-
-func TestDualWriteConnectedToV4Mapped(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v4 mapped address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, unicastV4in6)
-
- // Write to v6 address.
- testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{})
-}
-
-func TestV4WriteOnV6Only(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV6Only)
-
- // Write to V4 mapped address.
- testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{})
-}
-
-func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to v4 mapped address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Write to v6 address.
- testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{})
-}
-
-func TestV6WriteOnConnected(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v6 address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- testWriteWithoutDestination(c, unicastV6)
-}
-
-func TestV4WriteOnConnected(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v4 mapped address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- testWriteWithoutDestination(c, unicastV4)
-}
-
-func TestWriteOnConnectedInvalidPort(t *testing.T) {
- protocols := map[string]tcpip.NetworkProtocolNumber{
- "ipv4": ipv4.ProtocolNumber,
- "ipv6": ipv6.ProtocolNumber,
- }
- for name, pn := range protocols {
- t.Run(name, func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(pn)
- if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
- writeOpts := tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
- }
- var r bytes.Reader
- payload := newPayload()
- r.Reset(payload)
- n, err := c.ep.Write(&r, writeOpts)
- if err != nil {
- c.t.Fatalf("c.ep.Write(...) = %s, want nil", err)
- }
- if got, want := n, int64(len(payload)); got != want {
- c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
- }
-
- {
- err := c.ep.LastError()
- if _, ok := err.(*tcpip.ErrConnectionRefused); !ok {
- c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
- }
- }
- })
- }
-}
-
-// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
-// that is bound to a V4 multicast address.
-func TestWriteOnBoundToV4Multicast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4 mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
-// socket that is bound to a V4-mapped multicast address.
-func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4Mapped mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
-// socket that is bound to a V6 multicast address.
-func TestWriteOnBoundToV6Multicast(t *testing.T) {
- for _, flow := range []testFlow{unicastV6, multicastV6} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V6 mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
-// V6-only socket that is bound to a V6 multicast address.
-func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
- for _, flow := range []testFlow{unicastV6Only, multicastV6Only} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V6 mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToBroadcast checks that we can send packets out of a
-// socket that is bound to the broadcast address.
-func TestWriteOnBoundToBroadcast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4 broadcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
-// socket that is bound to the V4-mapped broadcast address.
-func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4Mapped mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-func TestReadIncrementsPacketsReceived(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- // Create IPv4 UDP endpoint
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testRead(c, unicastV4)
-
- var want uint64 = 1
- if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
- c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want)
- }
-}
-
-func TestReadIPPacketInfo(t *testing.T) {
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- flow testFlow
- checker func(tcpip.NICID) checker.ControlMessagesChecker
- }{
- {
- name: "IPv4 unicast",
- proto: header.IPv4ProtocolNumber,
- flow: unicastV4,
- checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
- return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
- NIC: id,
- LocalAddr: stackAddr,
- DestinationAddr: stackAddr,
- })
- },
- },
- {
- name: "IPv4 multicast",
- proto: header.IPv4ProtocolNumber,
- flow: multicastV4,
- checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
- return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
- NIC: id,
- // TODO(gvisor.dev/issue/3556): Check for a unicast address.
- LocalAddr: multicastAddr,
- DestinationAddr: multicastAddr,
- })
- },
- },
- {
- name: "IPv4 broadcast",
- proto: header.IPv4ProtocolNumber,
- flow: broadcast,
- checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
- return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
- NIC: id,
- // TODO(gvisor.dev/issue/3556): Check for a unicast address.
- LocalAddr: broadcastAddr,
- DestinationAddr: broadcastAddr,
- })
- },
- },
- {
- name: "IPv6 unicast",
- proto: header.IPv6ProtocolNumber,
- flow: unicastV6,
- checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
- return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
- NIC: id,
- Addr: stackV6Addr,
- })
- },
- },
- {
- name: "IPv6 multicast",
- proto: header.IPv6ProtocolNumber,
- flow: multicastV6,
- checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
- return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
- NIC: id,
- Addr: multicastV6Addr,
- })
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(test.proto)
-
- bindAddr := tcpip.FullAddress{Port: stackPort}
- if err := c.ep.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%+v): %s", bindAddr, err)
- }
-
- if test.flow.isMulticast() {
- ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
- if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
- }
- }
-
- switch f := test.flow.netProto(); f {
- case header.IPv4ProtocolNumber:
- c.ep.SocketOptions().SetReceivePacketInfo(true)
- case header.IPv6ProtocolNumber:
- c.ep.SocketOptions().SetIPv6ReceivePacketInfo(true)
- default:
- t.Fatalf("unhandled protocol number = %d", f)
- }
-
- testRead(c, test.flow, test.checker(c.nicID))
-
- if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
- t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestReadRecvOriginalDstAddr(t *testing.T) {
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- flow testFlow
- expectedOriginalDstAddr tcpip.FullAddress
- }{
- {
- name: "IPv4 unicast",
- proto: header.IPv4ProtocolNumber,
- flow: unicastV4,
- expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort},
- },
- {
- name: "IPv4 multicast",
- proto: header.IPv4ProtocolNumber,
- flow: multicastV4,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort},
- },
- {
- name: "IPv4 broadcast",
- proto: header.IPv4ProtocolNumber,
- flow: broadcast,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort},
- },
- {
- name: "IPv6 unicast",
- proto: header.IPv6ProtocolNumber,
- flow: unicastV6,
- expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort},
- },
- {
- name: "IPv6 multicast",
- proto: header.IPv6ProtocolNumber,
- flow: multicastV6,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(test.proto)
-
- bindAddr := tcpip.FullAddress{Port: stackPort}
- if err := c.ep.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%#v): %s", bindAddr, err)
- }
-
- if test.flow.isMulticast() {
- ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
- if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
- }
- }
-
- c.ep.SocketOptions().SetReceiveOriginalDstAddress(true)
-
- testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr))
-
- if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
- t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestWriteIncrementsPacketsSent(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- testDualWrite(c)
-
- var want uint64 = 2
- if got := c.s.Stats().UDP.PacketsSent.Value(); got != want {
- c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
- }
-}
-
-func TestNoChecksum(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Disable the checksum generation.
- c.ep.SocketOptions().SetNoChecksum(true)
- // This option is effective on IPv4 only.
- testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
-
- // Enable the checksum generation.
- c.ep.SocketOptions().SetNoChecksum(false)
- testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
- })
- }
-}
-
-var _ stack.NetworkInterface = (*testInterface)(nil)
-
-type testInterface struct {
- stack.NetworkInterface
-}
-
-func (*testInterface) ID() tcpip.NICID {
- return 0
-}
-
-func (*testInterface) Enabled() bool {
- return true
-}
-
-func TestTTL(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- const multicastTTL = 42
- if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil {
- c.t.Fatalf("SetSockOptInt failed: %s", err)
- }
-
- var wantTTL uint8
- if flow.isMulticast() {
- wantTTL = multicastTTL
- } else {
- var p stack.NetworkProtocolFactory
- var n tcpip.NetworkProtocolNumber
- if flow.isV4() {
- p = ipv4.NewProtocol
- n = ipv4.ProtocolNumber
- } else {
- p = ipv6.NewProtocol
- n = ipv6.ProtocolNumber
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{p},
- Clock: &faketime.NullClock{},
- })
- ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil)
- wantTTL = ep.DefaultTTL()
- ep.Close()
- }
-
- testWrite(c, flow, checker.TTL(wantTTL))
- })
- }
-}
-
-func TestSetTTL(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
- t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
- c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
- }
-
- testWrite(c, flow, checker.TTL(wantTTL))
- })
- }
- })
- }
-}
-
-var v4PacketFlows = [...]testFlow{unicastV4, multicastV4, broadcast, unicastV4in6, multicastV4in6, broadcastIn6}
-
-func TestSetTOS(t *testing.T) {
- for _, flow := range v4PacketFlows {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- const tos = testTOS
- v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
- if err != nil {
- c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
- }
- // Test for expected default value.
- if v != 0 {
- c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
- }
-
- if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
- c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err)
- }
-
- v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
- if err != nil {
- c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
- }
-
- if v != tos {
- c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos)
- }
-
- testWrite(c, flow, checker.TOS(tos, 0))
- })
- }
-}
-
-var v6PacketFlows = [...]testFlow{unicastV6, unicastV6Only, multicastV6}
-
-func TestSetTClass(t *testing.T) {
- for _, flow := range v6PacketFlows {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- const tClass = testTOS
- v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
- if err != nil {
- c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
- }
- // Test for expected default value.
- if v != 0 {
- c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0)
- }
-
- if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil {
- c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err)
- }
-
- v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
- if err != nil {
- c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
- }
-
- if v != tClass {
- c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass)
- }
-
- // The header getter for TClass is called TOS, so use that checker.
- testWrite(c, flow, checker.TOS(tClass, 0))
- })
- }
-}
-
-func TestReceiveTosTClass(t *testing.T) {
- const RcvTOSOpt = "ReceiveTosOption"
- const RcvTClassOpt = "ReceiveTClassOption"
-
- testCases := []struct {
- name string
- tests []testFlow
- }{
- {
- name: RcvTOSOpt,
- tests: v4PacketFlows[:],
- },
- {
- name: RcvTClassOpt,
- tests: v6PacketFlows[:],
- },
- }
- for _, testCase := range testCases {
- for _, flow := range testCase.tests {
- t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
- name := testCase.name
-
- if flow.isMulticast() {
- netProto := flow.netProto()
- addr := flow.getMcastAddr()
- if err := c.s.JoinGroup(netProto, c.nicID, addr); err != nil {
- c.t.Fatalf("JoinGroup(%d, %d, %s): %s", netProto, c.nicID, addr, err)
- }
- }
-
- var optionGetter func() bool
- var optionSetter func(bool)
- switch name {
- case RcvTOSOpt:
- optionGetter = c.ep.SocketOptions().GetReceiveTOS
- optionSetter = c.ep.SocketOptions().SetReceiveTOS
- case RcvTClassOpt:
- optionGetter = c.ep.SocketOptions().GetReceiveTClass
- optionSetter = c.ep.SocketOptions().SetReceiveTClass
- default:
- t.Fatalf("unkown test variant: %s", name)
- }
-
- // Verify that setting and reading the option works.
- v := optionGetter()
- // Test for expected default value.
- if v != false {
- c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
- }
-
- const want = true
- optionSetter(want)
-
- got := optionGetter()
- if got != want {
- c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
- }
-
- // Verify that the correct received TOS or TClass is handed through as
- // ancillary data to the ControlMessages struct.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
- switch name {
- case RcvTClassOpt:
- testRead(c, flow, checker.ReceiveTClass(testTOS))
- case RcvTOSOpt:
- testRead(c, flow, checker.ReceiveTOS(testTOS))
- default:
- t.Fatalf("unknown test variant: %s", name)
- }
- })
- }
- }
-}
-
-func TestMulticastInterfaceOption(t *testing.T) {
- for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- for _, bindTyp := range []string{"bound", "unbound"} {
- t.Run(bindTyp, func(t *testing.T) {
- for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
- t.Run(optTyp, func(t *testing.T) {
- h := flow.header4Tuple(outgoing)
- mcastAddr := h.dstAddr.Addr
- localIfAddr := h.srcAddr.Addr
-
- var ifoptSet tcpip.MulticastInterfaceOption
- switch optTyp {
- case "use local-addr":
- ifoptSet.InterfaceAddr = localIfAddr
- case "use NICID":
- ifoptSet.NIC = 1
- case "use local-addr and NIC":
- ifoptSet.InterfaceAddr = localIfAddr
- ifoptSet.NIC = 1
- default:
- t.Fatal("unknown test variant")
- }
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(flow.sockProto())
-
- if bindTyp == "bound" {
- // Bind the socket by connecting to the multicast address.
- // This may have an influence on how the multicast interface
- // is set.
- addr := tcpip.FullAddress{
- Addr: flow.mapAddrIfApplicable(mcastAddr),
- Port: stackPort,
- }
- if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
- }
-
- if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
- }
-
- // Verify multicast interface addr and NIC were set correctly.
- // Note that NIC must be 1 since this is our outgoing interface.
- var ifoptGot tcpip.MulticastInterfaceOption
- if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err)
- } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant {
- c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant)
- }
- })
- }
- })
- }
- })
- }
-}
-
-// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
-// Unreachable message when a udp datagram is received on ports for which there
-// is no bound udp socket.
-func TestV4UnknownDestination(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- testCases := []struct {
- flow testFlow
- icmpRequired bool
- // largePayload if true, will result in a payload large enough
- // so that the final generated IPv4 packet is larger than
- // header.IPv4MinimumProcessableDatagramSize.
- largePayload bool
- // badChecksum if true, will set an invalid checksum in the
- // header.
- badChecksum bool
- }{
- {unicastV4, true, false, false},
- {unicastV4, true, true, false},
- {unicastV4, false, false, true},
- {unicastV4, false, true, true},
- {multicastV4, false, false, false},
- {multicastV4, false, true, false},
- {broadcast, false, false, false},
- {broadcast, false, true, false},
- }
- checksumErrors := uint64(0)
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
- payload := newPayload()
- if tc.largePayload {
- payload = newMinPayload(576)
- }
- c.injectPacket(tc.flow, payload, tc.badChecksum)
- if tc.badChecksum {
- checksumErrors++
- if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
- t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- }
- if !tc.icmpRequired {
- if p, ok := c.linkEP.Read(); ok {
- t.Fatalf("unexpected packet received: %+v", p)
- }
- return
- }
-
- // ICMP required.
- p, ok := c.linkEP.Read()
- if !ok {
- t.Fatalf("packet wasn't written out")
- return
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- pkt := vv.ToView()
- if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
-
- hdr := header.IPv4(pkt)
- checker.IPv4(t, hdr, checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
-
- // We need to compare the included data part of the UDP packet that is in
- // the ICMP packet with the matching original data.
- icmpPkt := header.ICMPv4(hdr.Payload())
- payloadIPHeader := header.IPv4(icmpPkt.Payload())
- incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
- wantLen := len(payload)
- if tc.largePayload {
- // To work out the data size we need to simulate what the sender would
- // have done. The wanted size is the total available minus the sum of
- // the headers in the UDP AND ICMP packets, given that we know the test
- // had only a minimal IP header but the ICMP sender will have allowed
- // for a maximally sized packet header.
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
- }
-
- // In the case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- // Add back the two headers within the payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
-
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
- }
- })
- }
-}
-
-// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
-// Unreachable message when a udp datagram is received on ports for which there
-// is no bound udp socket.
-func TestV6UnknownDestination(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- testCases := []struct {
- flow testFlow
- icmpRequired bool
- // largePayload if true will result in a payload large enough to
- // create an IPv6 packet > header.IPv6MinimumMTU bytes.
- largePayload bool
- // badChecksum if true, will set an invalid checksum in the
- // header.
- badChecksum bool
- }{
- {unicastV6, true, false, false},
- {unicastV6, true, true, false},
- {unicastV6, false, false, true},
- {unicastV6, false, true, true},
- {multicastV6, false, false, false},
- {multicastV6, false, true, false},
- }
- checksumErrors := uint64(0)
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
- payload := newPayload()
- if tc.largePayload {
- payload = newMinPayload(1280)
- }
- c.injectPacket(tc.flow, payload, tc.badChecksum)
- if tc.badChecksum {
- checksumErrors++
- if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
- t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- }
- if !tc.icmpRequired {
- if p, ok := c.linkEP.Read(); ok {
- t.Fatalf("unexpected packet received: %+v", p)
- }
- return
- }
-
- // ICMP required.
- p, ok := c.linkEP.Read()
- if !ok {
- t.Fatalf("packet wasn't written out")
- return
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- pkt := vv.ToView()
- if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
-
- hdr := header.IPv6(pkt)
- checker.IPv6(t, hdr, checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
-
- icmpPkt := header.ICMPv6(hdr.Payload())
- payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
- }
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
-
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
- }
- })
- }
-}
-
-// TestIncrementMalformedPacketsReceived verifies if the malformed received
-// global and endpoint stats are incremented.
-func TestIncrementMalformedPacketsReceived(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV6.header4Tuple(incoming)
- buf := c.buildV6Packet(payload, &h)
-
- // Invalidate the UDP header length field.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.SetLength(u.Length() + 1)
-
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 1
- if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
- }
-}
-
-// TestShortHeader verifies that when a packet with a too-short UDP header is
-// received, the malformed received global stat gets incremented.
-func TestShortHeader(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- h := unicastV6.header4Tuple(incoming)
-
- // Allocate a buffer for an IPv6 and too-short UDP header.
- const udpSize = header.UDPMinimumSize - 1
- buf := buffer.NewView(header.IPv6MinimumSize + udpSize)
- // Initialize the IP header.
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(udpSize),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
- })
-
- // Initialize the UDP header.
- udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize))
- udpHdr.Encode(&header.UDPFields{
- SrcPort: h.srcAddr.Port,
- DstPort: h.dstAddr.Port,
- Length: header.UDPMinimumSize,
- })
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr)))
- udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
- // Copy all but the last byte of the UDP header into the packet.
- copy(buf[header.IPv6MinimumSize:], udpHdr)
-
- // Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want {
- t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want)
- }
-}
-
-// TestBadChecksumErrors verifies if a checksum error is detected,
-// global and endpoint stats are incremented.
-func TestBadChecksumErrors(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV6} {
- t.Run(flow.String(), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(flow.sockProto())
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- c.injectPacket(flow, payload, true /* badChecksum */)
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
- })
- }
-}
-
-// TestPayloadModifiedV4 verifies if a checksum error is detected,
-// global and endpoint stats are incremented.
-func TestPayloadModifiedV4(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv4.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV4.header4Tuple(incoming)
- buf := c.buildV4Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be
- // incorrect.
- buf[len(buf)-1]++
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestPayloadModifiedV6 verifies if a checksum error is detected,
-// global and endpoint stats are incremented.
-func TestPayloadModifiedV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV6.header4Tuple(incoming)
- buf := c.buildV6Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be
- // incorrect.
- buf[len(buf)-1]++
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestChecksumZeroV4 verifies if the checksum value is zero, global and
-// endpoint states are *not* incremented (UDP checksum is optional on IPv4).
-func TestChecksumZeroV4(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv4.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV4.header4Tuple(incoming)
- buf := c.buildV4Packet(payload, &h)
- // Set the checksum field in the UDP header to zero.
- u := header.UDP(buf[header.IPv4MinimumSize:])
- u.SetChecksum(0)
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 0
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestChecksumZeroV6 verifies if the checksum value is zero, global and
-// endpoint states are incremented (UDP checksum is *not* optional on IPv6).
-func TestChecksumZeroV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV6.header4Tuple(incoming)
- buf := c.buildV6Packet(payload, &h)
- // Set the checksum field in the UDP header to zero.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.SetChecksum(0)
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestShutdownRead verifies endpoint read shutdown and error
-// stats increment on packet receive.
-func TestShutdownRead(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- testFailingRead(c, unicastV6, true /* expectReadError */)
-
- var want uint64 = 1
- if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
- }
-}
-
-// TestShutdownWrite verifies endpoint write shutdown and error
-// stats increment on packet write.
-func TestShutdownWrite(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{})
-}
-
-func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) {
- got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- switch err.(type) {
- case nil:
- want.PacketsSent.IncrementBy(incr)
- case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
- want.WriteErrors.InvalidArgs.IncrementBy(incr)
- case *tcpip.ErrClosedForSend:
- want.WriteErrors.WriteClosed.IncrementBy(incr)
- case *tcpip.ErrInvalidEndpointState:
- want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
- case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
- want.SendErrors.NoRoute.IncrementBy(incr)
- default:
- want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
- }
- if got != want {
- c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
- }
-}
-
-func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) {
- got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- switch err.(type) {
- case nil, *tcpip.ErrWouldBlock:
- case *tcpip.ErrClosedForReceive:
- want.ReadErrors.ReadClosed.IncrementBy(incr)
- default:
- c.t.Errorf("Endpoint error missing stats update err %v", err)
- }
- if got != want {
- c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
- }
-}
-
-func TestOutgoingSubnetBroadcast(t *testing.T) {
- const nicID1 = 1
-
- ipv4Addr := tcpip.AddressWithPrefix{
- Address: "\xc0\xa8\x01\x3a",
- PrefixLen: 24,
- }
- ipv4Subnet := ipv4Addr.Subnet()
- ipv4SubnetBcast := ipv4Subnet.Broadcast()
- ipv4Gateway := testutil.MustParse4("192.168.1.1")
- ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
- Address: "\xc0\xa8\x01\x3a",
- PrefixLen: 31,
- }
- ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
- ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
- ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
- Address: "\xc0\xa8\x01\x3a",
- PrefixLen: 32,
- }
- ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
- ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
- ipv6Addr := tcpip.AddressWithPrefix{
- Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- PrefixLen: 64,
- }
- ipv6Subnet := ipv6Addr.Subnet()
- ipv6SubnetBcast := ipv6Subnet.Broadcast()
- remNetAddr := tcpip.AddressWithPrefix{
- Address: "\x64\x0a\x7b\x18",
- PrefixLen: 24,
- }
- remNetSubnet := remNetAddr.Subnet()
- remNetSubnetBcast := remNetSubnet.Broadcast()
-
- tests := []struct {
- name string
- nicAddr tcpip.ProtocolAddress
- routes []tcpip.Route
- remoteAddr tcpip.Address
- requiresBroadcastOpt bool
- }{
- {
- name: "IPv4 Broadcast to local subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv4Subnet,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv4SubnetBcast,
- requiresBroadcastOpt: true,
- },
- {
- name: "IPv4 Broadcast to local /31 subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4AddrPrefix31,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv4Subnet31,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv4Subnet31Bcast,
- requiresBroadcastOpt: false,
- },
- {
- name: "IPv4 Broadcast to local /32 subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4AddrPrefix32,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv4Subnet32,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv4Subnet32Bcast,
- requiresBroadcastOpt: false,
- },
- // IPv6 has no notion of a broadcast.
- {
- name: "IPv6 'Broadcast' to local subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: ipv6Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv6Subnet,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv6SubnetBcast,
- requiresBroadcastOpt: false,
- },
- {
- name: "IPv4 Broadcast to remote subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: remNetSubnet,
- Gateway: ipv4Gateway,
- NIC: nicID1,
- },
- },
- remoteAddr: remNetSubnetBcast,
- // TODO(gvisor.dev/issue/3938): Once we support marking a route as
- // broadcast, this test should require the broadcast option to be set.
- requiresBroadcastOpt: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- Clock: &faketime.NullClock{},
- })
- e := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID1, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
- }
-
- s.SetRouteTable(test.routes)
-
- var netProto tcpip.NetworkProtocolNumber
- switch l := len(test.remoteAddr); l {
- case header.IPv4AddressSize:
- netProto = header.IPv4ProtocolNumber
- case header.IPv6AddressSize:
- netProto = header.IPv6ProtocolNumber
- default:
- t.Fatalf("got unexpected address length = %d bytes", l)
- }
-
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err)
- }
- defer ep.Close()
-
- var r bytes.Reader
- data := []byte{1, 2, 3, 4}
- to := tcpip.FullAddress{
- Addr: test.remoteAddr,
- Port: 80,
- }
- opts := tcpip.WriteOptions{To: &to}
- expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error {
- if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok {
- return nil
- }
- return &tcpip.ErrBroadcastDisabled{}
- }
- if !test.requiresBroadcastOpt {
- expectedErrWithoutBcastOpt = nil
- }
-
- r.Reset(data)
- {
- n, err := ep.Write(&r, opts)
- if expectedErrWithoutBcastOpt != nil {
- if want := expectedErrWithoutBcastOpt(err); want != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
- }
- } else if err != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
- }
- }
-
- ep.SocketOptions().SetBroadcast(true)
-
- r.Reset(data)
- if n, err := ep.Write(&r, opts); err != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
- }
-
- ep.SocketOptions().SetBroadcast(false)
-
- r.Reset(data)
- {
- n, err := ep.Write(&r, opts)
- if expectedErrWithoutBcastOpt != nil {
- if want := expectedErrWithoutBcastOpt(err); want != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
- }
- } else if err != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
- }
- }
- })
- }
-}