diff options
Diffstat (limited to 'pkg/tcpip/transport')
33 files changed, 3005 insertions, 16326 deletions
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD deleted file mode 100644 index 7e5c79776..000000000 --- a/pkg/tcpip/transport/icmp/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "icmp_packet_list", - out = "icmp_packet_list.go", - package = "icmp", - prefix = "icmpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*icmpPacket", - "Linker": "*icmpPacket", - }, -) - -go_library( - name = "icmp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "icmp_packet_list.go", - "protocol.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/ports", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go new file mode 100644 index 000000000..0aacdad3f --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go @@ -0,0 +1,221 @@ +package icmp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type icmpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type icmpPacketList struct { + head *icmpPacket + tail *icmpPacket +} + +// Reset resets list l to the empty state. +func (l *icmpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *icmpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *icmpPacketList) Front() *icmpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *icmpPacketList) Back() *icmpPacket { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *icmpPacketList) Len() (count int) { + for e := l.Front(); e != nil; e = (icmpPacketElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *icmpPacketList) PushFront(e *icmpPacket) { + linker := icmpPacketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *icmpPacketList) PushBack(e *icmpPacket) { + linker := icmpPacketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *icmpPacketList) PushBackList(m *icmpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) { + bLinker := icmpPacketElementMapper{}.linkerFor(b) + eLinker := icmpPacketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + icmpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) { + aLinker := icmpPacketElementMapper{}.linkerFor(a) + eLinker := icmpPacketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + icmpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *icmpPacketList) Remove(e *icmpPacket) { + linker := icmpPacketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + icmpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type icmpPacketEntry struct { + next *icmpPacket + prev *icmpPacket +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *icmpPacketEntry) Next() *icmpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *icmpPacketEntry) Prev() *icmpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *icmpPacketEntry) SetNext(elem *icmpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go new file mode 100644 index 000000000..98c1b06a3 --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go @@ -0,0 +1,167 @@ +// automatically generated by stateify. + +package icmp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (p *icmpPacket) StateTypeName() string { + return "pkg/tcpip/transport/icmp.icmpPacket" +} + +func (p *icmpPacket) StateFields() []string { + return []string{ + "icmpPacketEntry", + "senderAddress", + "data", + "timestamp", + } +} + +func (p *icmpPacket) beforeSave() {} + +// +checklocksignore +func (p *icmpPacket) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var dataValue buffer.VectorisedView = p.saveData() + stateSinkObject.SaveValue(2, dataValue) + stateSinkObject.Save(0, &p.icmpPacketEntry) + stateSinkObject.Save(1, &p.senderAddress) + stateSinkObject.Save(3, &p.timestamp) +} + +func (p *icmpPacket) afterLoad() {} + +// +checklocksignore +func (p *icmpPacket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.icmpPacketEntry) + stateSourceObject.Load(1, &p.senderAddress) + stateSourceObject.Load(3, &p.timestamp) + stateSourceObject.LoadValue(2, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) }) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/icmp.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "waiterQueue", + "uniqueID", + "rcvReady", + "rcvList", + "rcvBufSize", + "rcvClosed", + "shutdownFlags", + "state", + "ttl", + "owner", + "ops", + "frozen", + } +} + +// +checklocksignore +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.TransportEndpointInfo) + stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &e.waiterQueue) + stateSinkObject.Save(3, &e.uniqueID) + stateSinkObject.Save(4, &e.rcvReady) + stateSinkObject.Save(5, &e.rcvList) + stateSinkObject.Save(6, &e.rcvBufSize) + stateSinkObject.Save(7, &e.rcvClosed) + stateSinkObject.Save(8, &e.shutdownFlags) + stateSinkObject.Save(9, &e.state) + stateSinkObject.Save(10, &e.ttl) + stateSinkObject.Save(11, &e.owner) + stateSinkObject.Save(12, &e.ops) + stateSinkObject.Save(13, &e.frozen) +} + +// +checklocksignore +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.TransportEndpointInfo) + stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &e.waiterQueue) + stateSourceObject.Load(3, &e.uniqueID) + stateSourceObject.Load(4, &e.rcvReady) + stateSourceObject.Load(5, &e.rcvList) + stateSourceObject.Load(6, &e.rcvBufSize) + stateSourceObject.Load(7, &e.rcvClosed) + stateSourceObject.Load(8, &e.shutdownFlags) + stateSourceObject.Load(9, &e.state) + stateSourceObject.Load(10, &e.ttl) + stateSourceObject.Load(11, &e.owner) + stateSourceObject.Load(12, &e.ops) + stateSourceObject.Load(13, &e.frozen) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (l *icmpPacketList) StateTypeName() string { + return "pkg/tcpip/transport/icmp.icmpPacketList" +} + +func (l *icmpPacketList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *icmpPacketList) beforeSave() {} + +// +checklocksignore +func (l *icmpPacketList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *icmpPacketList) afterLoad() {} + +// +checklocksignore +func (l *icmpPacketList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *icmpPacketEntry) StateTypeName() string { + return "pkg/tcpip/transport/icmp.icmpPacketEntry" +} + +func (e *icmpPacketEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *icmpPacketEntry) beforeSave() {} + +// +checklocksignore +func (e *icmpPacketEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *icmpPacketEntry) afterLoad() {} + +// +checklocksignore +func (e *icmpPacketEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*icmpPacket)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*icmpPacketList)(nil)) + state.Register((*icmpPacketEntry)(nil)) +} diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD deleted file mode 100644 index b989b1209..000000000 --- a/pkg/tcpip/transport/packet/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "packet_list", - out = "packet_list.go", - package = "packet", - prefix = "packet", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*packet", - "Linker": "*packet", - }, -) - -go_library( - name = "packet", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/packet/packet_list.go b/pkg/tcpip/transport/packet/packet_list.go new file mode 100644 index 000000000..2c983aad0 --- /dev/null +++ b/pkg/tcpip/transport/packet/packet_list.go @@ -0,0 +1,221 @@ +package packet + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type packetElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (packetElementMapper) linkerFor(elem *packet) *packet { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type packetList struct { + head *packet + tail *packet +} + +// Reset resets list l to the empty state. +func (l *packetList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *packetList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *packetList) Front() *packet { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *packetList) Back() *packet { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *packetList) Len() (count int) { + for e := l.Front(); e != nil; e = (packetElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *packetList) PushFront(e *packet) { + linker := packetElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + packetElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *packetList) PushBack(e *packet) { + linker := packetElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *packetList) PushBackList(m *packetList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(m.head) + packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *packetList) InsertAfter(b, e *packet) { + bLinker := packetElementMapper{}.linkerFor(b) + eLinker := packetElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + packetElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *packetList) InsertBefore(a, e *packet) { + aLinker := packetElementMapper{}.linkerFor(a) + eLinker := packetElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + packetElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *packetList) Remove(e *packet) { + linker := packetElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + packetElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + packetElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type packetEntry struct { + next *packet + prev *packet +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *packetEntry) Next() *packet { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *packetEntry) Prev() *packet { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *packetEntry) SetNext(elem *packet) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *packetEntry) SetPrev(elem *packet) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go new file mode 100644 index 000000000..b354c87b1 --- /dev/null +++ b/pkg/tcpip/transport/packet/packet_state_autogen.go @@ -0,0 +1,170 @@ +// automatically generated by stateify. + +package packet + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (p *packet) StateTypeName() string { + return "pkg/tcpip/transport/packet.packet" +} + +func (p *packet) StateFields() []string { + return []string{ + "packetEntry", + "data", + "timestampNS", + "senderAddr", + "packetInfo", + } +} + +func (p *packet) beforeSave() {} + +// +checklocksignore +func (p *packet) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var dataValue buffer.VectorisedView = p.saveData() + stateSinkObject.SaveValue(1, dataValue) + stateSinkObject.Save(0, &p.packetEntry) + stateSinkObject.Save(2, &p.timestampNS) + stateSinkObject.Save(3, &p.senderAddr) + stateSinkObject.Save(4, &p.packetInfo) +} + +func (p *packet) afterLoad() {} + +// +checklocksignore +func (p *packet) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.packetEntry) + stateSourceObject.Load(2, &p.timestampNS) + stateSourceObject.Load(3, &p.senderAddr) + stateSourceObject.Load(4, &p.packetInfo) + stateSourceObject.LoadValue(1, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) }) +} + +func (ep *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/packet.endpoint" +} + +func (ep *endpoint) StateFields() []string { + return []string{ + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "netProto", + "waiterQueue", + "cooked", + "rcvList", + "rcvBufSize", + "rcvClosed", + "closed", + "bound", + "boundNIC", + "lastError", + "ops", + "frozen", + } +} + +// +checklocksignore +func (ep *endpoint) StateSave(stateSinkObject state.Sink) { + ep.beforeSave() + stateSinkObject.Save(0, &ep.TransportEndpointInfo) + stateSinkObject.Save(1, &ep.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &ep.netProto) + stateSinkObject.Save(3, &ep.waiterQueue) + stateSinkObject.Save(4, &ep.cooked) + stateSinkObject.Save(5, &ep.rcvList) + stateSinkObject.Save(6, &ep.rcvBufSize) + stateSinkObject.Save(7, &ep.rcvClosed) + stateSinkObject.Save(8, &ep.closed) + stateSinkObject.Save(9, &ep.bound) + stateSinkObject.Save(10, &ep.boundNIC) + stateSinkObject.Save(11, &ep.lastError) + stateSinkObject.Save(12, &ep.ops) + stateSinkObject.Save(13, &ep.frozen) +} + +// +checklocksignore +func (ep *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &ep.TransportEndpointInfo) + stateSourceObject.Load(1, &ep.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &ep.netProto) + stateSourceObject.Load(3, &ep.waiterQueue) + stateSourceObject.Load(4, &ep.cooked) + stateSourceObject.Load(5, &ep.rcvList) + stateSourceObject.Load(6, &ep.rcvBufSize) + stateSourceObject.Load(7, &ep.rcvClosed) + stateSourceObject.Load(8, &ep.closed) + stateSourceObject.Load(9, &ep.bound) + stateSourceObject.Load(10, &ep.boundNIC) + stateSourceObject.Load(11, &ep.lastError) + stateSourceObject.Load(12, &ep.ops) + stateSourceObject.Load(13, &ep.frozen) + stateSourceObject.AfterLoad(ep.afterLoad) +} + +func (l *packetList) StateTypeName() string { + return "pkg/tcpip/transport/packet.packetList" +} + +func (l *packetList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *packetList) beforeSave() {} + +// +checklocksignore +func (l *packetList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *packetList) afterLoad() {} + +// +checklocksignore +func (l *packetList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *packetEntry) StateTypeName() string { + return "pkg/tcpip/transport/packet.packetEntry" +} + +func (e *packetEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *packetEntry) beforeSave() {} + +// +checklocksignore +func (e *packetEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *packetEntry) afterLoad() {} + +// +checklocksignore +func (e *packetEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*packet)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*packetList)(nil)) + state.Register((*packetEntry)(nil)) +} diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD deleted file mode 100644 index 2eab09088..000000000 --- a/pkg/tcpip/transport/raw/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "raw_packet_list", - out = "raw_packet_list.go", - package = "raw", - prefix = "rawPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*rawPacket", - "Linker": "*rawPacket", - }, -) - -go_library( - name = "raw", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "protocol.go", - "raw_packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/packet", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/raw/raw_packet_list.go b/pkg/tcpip/transport/raw/raw_packet_list.go new file mode 100644 index 000000000..48804ff1b --- /dev/null +++ b/pkg/tcpip/transport/raw/raw_packet_list.go @@ -0,0 +1,221 @@ +package raw + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type rawPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (rawPacketElementMapper) linkerFor(elem *rawPacket) *rawPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type rawPacketList struct { + head *rawPacket + tail *rawPacket +} + +// Reset resets list l to the empty state. +func (l *rawPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *rawPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *rawPacketList) Front() *rawPacket { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *rawPacketList) Back() *rawPacket { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *rawPacketList) Len() (count int) { + for e := l.Front(); e != nil; e = (rawPacketElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *rawPacketList) PushFront(e *rawPacket) { + linker := rawPacketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + rawPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *rawPacketList) PushBack(e *rawPacket) { + linker := rawPacketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + rawPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *rawPacketList) PushBackList(m *rawPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + rawPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + rawPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *rawPacketList) InsertAfter(b, e *rawPacket) { + bLinker := rawPacketElementMapper{}.linkerFor(b) + eLinker := rawPacketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + rawPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *rawPacketList) InsertBefore(a, e *rawPacket) { + aLinker := rawPacketElementMapper{}.linkerFor(a) + eLinker := rawPacketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + rawPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *rawPacketList) Remove(e *rawPacket) { + linker := rawPacketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + rawPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + rawPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type rawPacketEntry struct { + next *rawPacket + prev *rawPacket +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *rawPacketEntry) Next() *rawPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *rawPacketEntry) Prev() *rawPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *rawPacketEntry) SetNext(elem *rawPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *rawPacketEntry) SetPrev(elem *rawPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go new file mode 100644 index 000000000..2bcd983a2 --- /dev/null +++ b/pkg/tcpip/transport/raw/raw_state_autogen.go @@ -0,0 +1,164 @@ +// automatically generated by stateify. + +package raw + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (p *rawPacket) StateTypeName() string { + return "pkg/tcpip/transport/raw.rawPacket" +} + +func (p *rawPacket) StateFields() []string { + return []string{ + "rawPacketEntry", + "data", + "timestampNS", + "senderAddr", + } +} + +func (p *rawPacket) beforeSave() {} + +// +checklocksignore +func (p *rawPacket) StateSave(stateSinkObject state.Sink) { + p.beforeSave() + var dataValue buffer.VectorisedView = p.saveData() + stateSinkObject.SaveValue(1, dataValue) + stateSinkObject.Save(0, &p.rawPacketEntry) + stateSinkObject.Save(2, &p.timestampNS) + stateSinkObject.Save(3, &p.senderAddr) +} + +func (p *rawPacket) afterLoad() {} + +// +checklocksignore +func (p *rawPacket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &p.rawPacketEntry) + stateSourceObject.Load(2, &p.timestampNS) + stateSourceObject.Load(3, &p.senderAddr) + stateSourceObject.LoadValue(1, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) }) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/raw.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "waiterQueue", + "associated", + "rcvList", + "rcvBufSize", + "rcvClosed", + "closed", + "connected", + "bound", + "owner", + "ops", + "frozen", + } +} + +// +checklocksignore +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.TransportEndpointInfo) + stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &e.waiterQueue) + stateSinkObject.Save(3, &e.associated) + stateSinkObject.Save(4, &e.rcvList) + stateSinkObject.Save(5, &e.rcvBufSize) + stateSinkObject.Save(6, &e.rcvClosed) + stateSinkObject.Save(7, &e.closed) + stateSinkObject.Save(8, &e.connected) + stateSinkObject.Save(9, &e.bound) + stateSinkObject.Save(10, &e.owner) + stateSinkObject.Save(11, &e.ops) + stateSinkObject.Save(12, &e.frozen) +} + +// +checklocksignore +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.TransportEndpointInfo) + stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &e.waiterQueue) + stateSourceObject.Load(3, &e.associated) + stateSourceObject.Load(4, &e.rcvList) + stateSourceObject.Load(5, &e.rcvBufSize) + stateSourceObject.Load(6, &e.rcvClosed) + stateSourceObject.Load(7, &e.closed) + stateSourceObject.Load(8, &e.connected) + stateSourceObject.Load(9, &e.bound) + stateSourceObject.Load(10, &e.owner) + stateSourceObject.Load(11, &e.ops) + stateSourceObject.Load(12, &e.frozen) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (l *rawPacketList) StateTypeName() string { + return "pkg/tcpip/transport/raw.rawPacketList" +} + +func (l *rawPacketList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *rawPacketList) beforeSave() {} + +// +checklocksignore +func (l *rawPacketList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *rawPacketList) afterLoad() {} + +// +checklocksignore +func (l *rawPacketList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *rawPacketEntry) StateTypeName() string { + return "pkg/tcpip/transport/raw.rawPacketEntry" +} + +func (e *rawPacketEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *rawPacketEntry) beforeSave() {} + +// +checklocksignore +func (e *rawPacketEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *rawPacketEntry) afterLoad() {} + +// +checklocksignore +func (e *rawPacketEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*rawPacket)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*rawPacketList)(nil)) + state.Register((*rawPacketEntry)(nil)) +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD deleted file mode 100644 index 0f20d3856..000000000 --- a/pkg/tcpip/transport/tcp/BUILD +++ /dev/null @@ -1,140 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test", "more_shards") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "tcp_segment_list", - out = "tcp_segment_list.go", - package = "tcp", - prefix = "segment", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*segment", - "Linker": "*segment", - }, -) - -go_template_instance( - name = "tcp_endpoint_list", - out = "tcp_endpoint_list.go", - package = "tcp", - prefix = "endpoint", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*endpoint", - "Linker": "*endpoint", - }, -) - -go_library( - name = "tcp", - srcs = [ - "accept.go", - "connect.go", - "connect_unsafe.go", - "cubic.go", - "dispatcher.go", - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "rack.go", - "rcv.go", - "rcv_state.go", - "reno.go", - "reno_recovery.go", - "sack.go", - "sack_recovery.go", - "sack_scoreboard.go", - "segment.go", - "segment_heap.go", - "segment_queue.go", - "segment_state.go", - "segment_unsafe.go", - "snd.go", - "snd_state.go", - "tcp_endpoint_list.go", - "tcp_segment_list.go", - "timer.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/rand", - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/hash/jenkins", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/ports", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/waiter", - "@com_github_google_btree//:go_default_library", - ], -) - -go_test( - name = "tcp_x_test", - size = "medium", - srcs = [ - "dual_stack_test.go", - "sack_scoreboard_test.go", - "tcp_noracedetector_test.go", - "tcp_rack_test.go", - "tcp_sack_test.go", - "tcp_test.go", - "tcp_timestamp_test.go", - ], - shard_count = more_shards, - deps = [ - ":tcp", - "//pkg/rand", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/tcp/testing/context", - "//pkg/test/testutil", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "rcv_test", - size = "small", - srcs = ["rcv_test.go"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( - name = "tcp_test", - size = "small", - srcs = [ - "segment_test.go", - "timer_test.go", - ], - library = ":tcp", - deps = [ - "//pkg/sleep", - "//pkg/tcpip/buffer", - "//pkg/tcpip/stack", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go deleted file mode 100644 index 5342aacfd..000000000 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ /dev/null @@ -1,650 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -func TestV4MappedConnectOnV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Start connection attempt, it must fail. - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } -} - -func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { - // Start connection attempt. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.WritableEvents) - defer c.WQ.EventUnregister(&we) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } - - // Receive SYN packet. - b := c.GetPacket() - synCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - )) - checker.IPv4(t, b, synCheckers...) - - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - iss := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - ackCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - )) - checker.IPv4(t, c.GetPacket(), ackCheckers...) - - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestV4MappedConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV4Connect(t, c) -} - -func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { - // Start connection attempt to IPv6 address. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.WritableEvents) - defer c.WQ.EventUnregister(&we) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) - } - - // Receive SYN packet. - b := c.GetV6Packet() - synCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - )) - checker.IPv6(t, b, synCheckers...) - - tcp := header.TCP(header.IPv6(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - iss := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - ackCheckers := append(checkers, checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - )) - checker.IPv6(t, c.GetV6Packet(), ackCheckers...) - - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestV6Connect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectWhenBoundToWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) { - c := context.NewWithOpts(t, context.Options{ - EnableV6: true, - MTU: defaultMTU, - }) - defer c.Cleanup() - - // Create a v6 endpoint but don't set the v6-only TCP option. - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to local address. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test the connection request. - testV6Connect(t, c) -} - -func TestV4RefuseOnV6Only(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the RST reply. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPAckNum(uint32(irs)+1), - ), - ) -} - -func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the RST reply. - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPAckNum(uint32(irs)+1), - ), - ) -} - -func testV4Accept(t *testing.T, c *context.Context) { - c.SetGSOEnabled(true) - defer c.SetGSOEnabled(false) - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv4(t, b, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1), - ), - ) - - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestAddr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) - } - - var r strings.Reader - data := "Don't panic" - r.Reset(data) - nep.Write(&r, tcpip.WriteOptions{}) - b = c.GetPacket() - tcp = header.IPv4(b).Payload() - if string(tcp.Payload()) != data { - t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) - } -} - -func TestV4AcceptOnV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind to v4 mapped wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV4AcceptOnBoundToV4Mapped(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func TestV6AcceptOnV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - // Bind and listen. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - // Send a SYN request. - irs := seqnum.Value(789) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetV6Packet() - tcp := header.TCP(header.IPv6(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - checker.IPv6(t, b, - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1), - ), - ) - - // Send ACK. - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - var addr tcpip.FullAddress - _, _, err := c.EP.Accept(&addr) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(&addr) - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - if addr.Addr != context.TestV6Addr { - t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) - } -} - -func TestV4AcceptOnV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4Accept(t, c) -} - -func testV4ListenClose(t *testing.T, c *context.Context) { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - const n = 32 - - // Start listening. - if err := c.EP.Listen(n); err != nil { - t.Fatalf("Listen failed: %v", err) - } - - irs := seqnum.Value(789) - for i := uint16(0); i < n; i++ { - // Send a SYN request. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + i, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - } - - // Each of these ACKs will cause a syn-cookie based connection to be - // accepted and delivered to the listening endpoint. - for i := 0; i < n; i++ { - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - } - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - nep, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(10 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - nep.Close() - c.EP.Close() -} - -func TestV4ListenCloseOnV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testV4ListenClose(t, c) -} diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go deleted file mode 100644 index 8a026ec46..000000000 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package rcv_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" -) - -func TestAcceptable(t *testing.T) { - for _, tt := range []struct { - segSeq seqnum.Value - segLen seqnum.Size - rcvNxt, rcvAcc seqnum.Value - want bool - }{ - // The segment is smaller than the window. - {105, 2, 100, 104, false}, - {105, 2, 101, 105, true}, - {105, 2, 102, 106, true}, - {105, 2, 103, 107, true}, - {105, 2, 104, 108, true}, - {105, 2, 105, 109, true}, - {105, 2, 106, 110, true}, - {105, 2, 107, 111, false}, - - // The segment is larger than the window. - {105, 4, 103, 105, true}, - {105, 4, 104, 106, true}, - {105, 4, 105, 107, true}, - {105, 4, 106, 108, true}, - {105, 4, 107, 109, true}, - {105, 4, 108, 110, true}, - {105, 4, 109, 111, false}, - {105, 4, 110, 112, false}, - - // The segment has no width. - {105, 0, 100, 102, false}, - {105, 0, 101, 103, false}, - {105, 0, 102, 104, false}, - {105, 0, 103, 105, true}, - {105, 0, 104, 106, true}, - {105, 0, 105, 107, true}, - {105, 0, 106, 108, false}, - {105, 0, 107, 109, false}, - - // The receive window has no width. - {105, 2, 103, 103, false}, - {105, 2, 104, 104, false}, - {105, 2, 105, 105, false}, - {105, 2, 106, 106, false}, - {105, 2, 107, 107, false}, - {105, 2, 108, 108, false}, - {105, 2, 109, 109, false}, - } { - if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want { - t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want) - } - } -} diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go deleted file mode 100644 index b4e5ba0df..000000000 --- a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" -) - -const smss = 1500 - -func initScoreboard(blocks []header.SACKBlock, iss seqnum.Value) *tcp.SACKScoreboard { - s := tcp.NewSACKScoreboard(smss, iss) - for _, blk := range blocks { - s.Insert(blk) - } - return s -} - -func TestSACKScoreboardIsSACKED(t *testing.T) { - type blockTest struct { - block header.SACKBlock - sacked bool - } - testCases := []struct { - comment string - scoreboardBlocks []header.SACKBlock - blockTests []blockTest - iss seqnum.Value - }{ - { - "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping SACK blocks", - []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}}, - []blockTest{ - {header.SACKBlock{15, 21}, true}, - {header.SACKBlock{200, 201}, false}, - {header.SACKBlock{50, 51}, false}, - {header.SACKBlock{53, 120}, true}, - }, - 0, - }, - { - "Test disjoint SACKBlocks", - []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}}, - []blockTest{ - {header.SACKBlock{2288624809, 2288810057}, true}, - {header.SACKBlock{2288811477, 2288838565}, true}, - {header.SACKBlock{2288810057, 2288811477}, false}, - }, - 2288624809, - }, - { - "Test sequence number wrap around", - []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}, - []blockTest{ - {header.SACKBlock{4294254144, 4294254145}, true}, - {header.SACKBlock{4294254143, 4294254144}, false}, - {header.SACKBlock{4294254144, 1}, true}, - {header.SACKBlock{225652, 5350509}, false}, - {header.SACKBlock{5340409, 5350509}, true}, - {header.SACKBlock{5350509, 5350609}, false}, - }, - 4294254144, - }, - { - "Test disjoint SACKBlocks out of order", - []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}}, - []blockTest{ - {header.SACKBlock{827426028, 827428867}, true}, - {header.SACKBlock{827450168, 827450275}, false}, - }, - 827426000, - }, - } - for _, tc := range testCases { - sb := initScoreboard(tc.scoreboardBlocks, tc.iss) - for _, blkTest := range tc.blockTests { - if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want { - t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want) - } - } - } -} - -func TestSACKScoreboardIsRangeLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - block header.SACKBlock - lost bool - }{ - // Block not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number covered by this block. - {block: header.SACKBlock{0, 1}, lost: true}, - - // These blocks have all been SACKed and should not be - // considered lost. - {block: header.SACKBlock{1, 2}, lost: false}, - {block: header.SACKBlock{25, 26}, lost: false}, - {block: header.SACKBlock{1, 45}, lost: false}, - - // Same as the first case above. - {block: header.SACKBlock{50, 51}, lost: true}, - - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{119, 120}, lost: false}, - - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {block: header.SACKBlock{120, 121}, lost: true}, - - // This block has been SACKed and should not be considered lost. - {block: header.SACKBlock{125, 126}, lost: false}, - - // This block has not been SACKed and there are nDupAckThreshold - // number of SACKed blocks after it. - {block: header.SACKBlock{141, 145}, lost: true}, - - // This block has not been SACKed and there are less than - // nDupAckThreshold SACKed sequences after it. - {block: header.SACKBlock{151, 152}, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsRangeLost(tc.block); got != want { - t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want) - } - } -} - -func TestSACKScoreboardIsLost(t *testing.T) { - s := tcp.NewSACKScoreboard(10, 0) - s.Insert(header.SACKBlock{1, 25}) - s.Insert(header.SACKBlock{25, 50}) - s.Insert(header.SACKBlock{51, 100}) - s.Insert(header.SACKBlock{111, 120}) - s.Insert(header.SACKBlock{101, 110}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{121, 141}) - s.Insert(header.SACKBlock{145, 146}) - s.Insert(header.SACKBlock{147, 148}) - s.Insert(header.SACKBlock{149, 150}) - s.Insert(header.SACKBlock{153, 154}) - s.Insert(header.SACKBlock{155, 156}) - testCases := []struct { - seq seqnum.Value - lost bool - }{ - // Sequence number not covered by SACK block and has more than - // nDupAckThreshold discontiguous SACK blocks after it as well - // as (nDupAckThreshold -1) * 10 (smss) bytes that have been - // SACKED above the sequence number. - {seq: 0, lost: true}, - - // These sequence numbers have all been SACKed and should not be - // considered lost. - {seq: 1, lost: false}, - {seq: 25, lost: false}, - {seq: 45, lost: false}, - - // Same as first case above. - {seq: 50, lost: true}, - - // This block has been SACKed and should not be considered lost. - {seq: 119, lost: false}, - - // This one should return true because there are > - // (nDupAckThreshold - 1) * 10 (smss) bytes that have been - // sacked above this sequence number. - {seq: 120, lost: true}, - - // This sequence number has been SACKed and should not be - // considered lost. - {seq: 125, lost: false}, - - // This sequence number has not been SACKed and there are - // nDupAckThreshold number of SACKed blocks after it. - {seq: 141, lost: true}, - - // This sequence number has not been SACKed and there are less - // than nDupAckThreshold SACKed sequences after it. - {seq: 151, lost: false}, - } - for _, tc := range testCases { - if want, got := tc.lost, s.IsLost(tc.seq); got != want { - t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want) - } - } -} - -func TestSACKScoreboardDelete(t *testing.T) { - blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}} - s := initScoreboard(blocks, 4294254143) - s.Delete(5340408) - if s.Empty() { - t.Fatalf("s.Empty() = true, want false") - } - if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } - s.Delete(5340410) - if s.Empty() { - t.Fatal("s.Empty() = true, want false") - } - newSB := header.SACKBlock{5340410, 5350509} - if !s.IsSACKED(newSB) { - t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s) - } - s.Delete(5350509) - lastOctet := header.SACKBlock{5350508, 5350509} - if s.IsSACKED(lastOctet) { - t.Fatalf("s.IsSACKED(%v) = false, want true", lastOctet) - } - - s.Delete(5350510) - if !s.Empty() { - t.Fatal("s.Empty() = false, want true") - } - if got, want := s.Sacked(), seqnum.Size(0); got != want { - t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want) - } -} diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go deleted file mode 100644 index 486016fc0..000000000 --- a/pkg/tcpip/transport/tcp/segment_test.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -type segmentSizeWants struct { - DataSize int - SegMemSize int -} - -func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeWants) { - t.Helper() - got := segmentSizeWants{ - DataSize: seg.data.Size(), - SegMemSize: seg.segMemSize(), - } - if diff := cmp.Diff(got, want); diff != "" { - t.Errorf("%s differs (-want +got):\n%s", name, diff) - } -} - -func TestSegmentMerge(t *testing.T) { - id := stack.TransportEndpointID{} - seg1 := newOutgoingSegment(id, buffer.NewView(10)) - defer seg1.decRef() - seg2 := newOutgoingSegment(id, buffer.NewView(20)) - defer seg2.decRef() - - checkSegmentSize(t, "seg1", seg1, segmentSizeWants{ - DataSize: 10, - SegMemSize: SegSize + 10, - }) - checkSegmentSize(t, "seg2", seg2, segmentSizeWants{ - DataSize: 20, - SegMemSize: SegSize + 20, - }) - - seg1.merge(seg2) - - checkSegmentSize(t, "seg1", seg1, segmentSizeWants{ - DataSize: 30, - SegMemSize: SegSize + 30, - }) - checkSegmentSize(t, "seg2", seg2, segmentSizeWants{ - DataSize: 0, - SegMemSize: SegSize, - }) -} diff --git a/pkg/tcpip/transport/tcp/tcp_endpoint_list.go b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go new file mode 100644 index 000000000..a7dc5df81 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go @@ -0,0 +1,221 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type endpointElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (endpointElementMapper) linkerFor(elem *endpoint) *endpoint { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type endpointList struct { + head *endpoint + tail *endpoint +} + +// Reset resets list l to the empty state. +func (l *endpointList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *endpointList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *endpointList) Front() *endpoint { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *endpointList) Back() *endpoint { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *endpointList) Len() (count int) { + for e := l.Front(); e != nil; e = (endpointElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *endpointList) PushFront(e *endpoint) { + linker := endpointElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + endpointElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *endpointList) PushBack(e *endpoint) { + linker := endpointElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + endpointElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *endpointList) PushBackList(m *endpointList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + endpointElementMapper{}.linkerFor(l.tail).SetNext(m.head) + endpointElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *endpointList) InsertAfter(b, e *endpoint) { + bLinker := endpointElementMapper{}.linkerFor(b) + eLinker := endpointElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + endpointElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *endpointList) InsertBefore(a, e *endpoint) { + aLinker := endpointElementMapper{}.linkerFor(a) + eLinker := endpointElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + endpointElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *endpointList) Remove(e *endpoint) { + linker := endpointElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + endpointElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + endpointElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type endpointEntry struct { + next *endpoint + prev *endpoint +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *endpointEntry) Next() *endpoint { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *endpointEntry) Prev() *endpoint { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *endpointEntry) SetNext(elem *endpoint) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *endpointEntry) SetPrev(elem *endpoint) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go deleted file mode 100644 index ced3a9c58..000000000 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ /dev/null @@ -1,558 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// These tests are flaky when run under the go race detector due to some -// iterations taking long enough that the retransmit timer can kick in causing -// the congestion window measurements to fail due to extra packets etc. -// -// +build !race - -package tcp_test - -import ( - "bytes" - "fmt" - "math" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/test/testutil" -) - -func TestFastRecovery(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - for i := 0; i < 3; i++ { - c.SendAck(790, rtxOffset) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Wait before checking metrics. - metricPollFn := func() error { - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) - } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) - } - - if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want) - } - return nil - } - - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Now send 7 mode duplicate acks. Each of these should cause a window - // inflation by 1 and cause the sender to send an extra packet. - for i := 0; i < 7; i++ { - c.SendAck(790, rtxOffset) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the retransmit due to partial ack. - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Wait before checking metrics. - metricPollFn = func() error { - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) - } - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want { - return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want) - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Receive the 10 extra packets that should have been released due to - // the congestion window inflation in recovery. - for i := 0; i < 10; i++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // A partial ACK during recovery should reduce congestion window by the - // number acked. Since we had "expected" packets outstanding before sending - // partial ack and we acked expected/2 , the cwnd and outstanding should - // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered - // fast recovery). Which means the sender should not send any more packets - // till we ack this one. - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", - 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(790, recover) - - // At this point, the cwnd should reset to expected/2 and there are 10 - // packets outstanding. - // - // NOTE: Technically netstack is incorrect in that we adjust the cwnd on - // the same segment that takes us out of recovery. But because of that - // the actual cwnd at exit of recovery will be expected/2 + 1 as we - // acked a cwnd worth of packets which will increase the cwnd further by - // 1 in congestion avoidance. - // - // Now in the first iteration since there are 10 packets outstanding. - // We would expect to get expected/2 +1 - 10 packets. But subsequent - // iterations will send us expected/2 + 1 + 1 (per iteration). - expected = expected/2 + 1 - 10 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 10 - } - expected++ - } -} - -func TestExponentialIncreaseDuringSlowStart(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // Double the number of expected packets for the next iteration. - expected *= 2 - } -} - -func TestCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd/2. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected/2. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 (which "consumes" expected/2-1 of the - // acknowledgements), then the congestion avoidance part will consume - // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack - // remains in the "ack count" (which will cause cwnd to be incremented - // once it reaches cwnd acks). - // - // So we're straight into congestion avoidance with cwnd set to - // expected/2 + 1. - // - // Check that packets trains of cwnd packets are sent, and that cwnd is - // incremented by 1 after we acknowledge each packet. - expected = expected/2 + 1 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - expected++ - } -} - -// cubicCwnd returns an estimate of a cubic window given the -// originalCwnd, wMax, last congestion event time and sRTT. -func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int { - cwnd := float64(origCwnd) - // We wait 50ms between each iteration so sRTT as computed by cubic - // should be close to 50ms. - elapsed := (time.Since(congEventTime) + sRTT).Seconds() - k := math.Cbrt(float64(wMax) * 0.3 / 0.7) - wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax) - cwnd += (wtRTT - cwnd) / cwnd - return int(cwnd) -} - -func TestCubicCongestionAvoidance(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - enableCUBIC(t, c) - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond) - } - - // Don't acknowledge the first packet of the last packet train. Let's - // wait for them to time out, which will trigger a restart of slow - // start, and initialization of ssthresh to cwnd * 0.7. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - // Acknowledge all pending data. - c.SendAck(790, bytesRead) - - // Store away the time we sent the ACK and assuming a 200ms RTO - // we estimate that the sender will have an RTO 200ms from now - // and go back into slow start. - packetDropTime := time.Now().Add(200 * time.Millisecond) - - // This part is tricky: when the timeout happened, we had "expected" - // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7. - // By acknowledging "expected" packets, the slow-start part will - // increase cwnd to expected/2 essentially putting the connection - // straight into congestion avoidance. - wMax := expected - // Lower expected as per cubic spec after a congestion event. - expected = int(float64(expected) * 0.7) - cwnd := expected - for i := 0; i < iterations; i++ { - // Cubic grows window independent of ACKs. Cubic Window growth - // is a function of time elapsed since last congestion event. - // As a result the congestion window does not grow - // deterministically in response to ACKs. - // - // We need to roughly estimate what the cwnd of the sender is - // based on when we sent the dupacks. - cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond) - - packetsExpected := cwnd - for j := 0; j < packetsExpected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - t.Logf("expected packets received, next trying to receive any extra packets that may come") - - // If our estimate was correct there should be no more pending packets. - // We attempt to read a packet a few times with a short sleep in between - // to ensure that we don't see the sender send any unexpected packets. - unexpectedPackets := 0 - for { - gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload) - if !gotPacket { - break - } - bytesRead += maxPayload - unexpectedPackets++ - time.Sleep(1 * time.Millisecond) - } - if unexpectedPackets != 0 { - t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i) - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(790, bytesRead) - } -} - -func TestRetransmit(t *testing.T) { - maxPayload := 32 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - - const iterations = 3 - data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in two shots. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data[:len(data)/2]) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - r.Reset(data[len(data)/2:]) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(790, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Wait for a timeout and retransmit. - rtxOffset := bytesRead - maxPayload*expected - c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) - - metricPollFn := func() error { - if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { - return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want) - } - - if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want) - } - - return nil - } - - // Poll when checking metrics. - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Acknowledge half of the pending data. - rtxOffset = bytesRead - expected*maxPayload/2 - c.SendAck(790, rtxOffset) - - // Receive the remaining data, making sure that acknowledged data is not - // retransmitted. - for offset := rtxOffset; offset < len(data); offset += maxPayload { - c.ReceiveAndCheckPacket(data, offset, maxPayload) - c.SendAck(790, offset+maxPayload) - } - - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) -} diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go deleted file mode 100644 index 29af87b7d..000000000 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ /dev/null @@ -1,1061 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "fmt" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/test/testutil" -) - -const ( - maxPayload = 10 - tsOptionSize = 12 - maxTCPOptionSize = 40 - mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload - latency = 5 * time.Millisecond -) - -func setStackRACKPermitted(t *testing.T, c *context.Context) { - t.Helper() - opt := tcpip.TCPRACKLossDetection - if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err) - } -} - -// TestRACKUpdate tests the RACK related fields are updated when an ACK is -// received on a SACK enabled connection. -func TestRACKUpdate(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - var xmitTime time.Time - probeDone := make(chan struct{}) - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint Sender.RACKState is what we expect. - if state.Sender.RACKState.XmitTime.Before(xmitTime) { - t.Fatalf("RACK transmit time failed to update when an ACK is received") - } - - gotSeq := state.Sender.RACKState.EndSequence - wantSeq := state.Sender.SndNxt - if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { - t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq) - } - - if state.Sender.RACKState.RTT == 0 { - t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0") - } - close(probeDone) - }) - setStackSACKPermitted(t, c, true) - setStackRACKPermitted(t, c) - createConnectedWithSACKAndTS(c) - - data := make([]byte, maxPayload) - for i := range data { - data[i] = byte(i) - } - - // Write the data. - xmitTime = time.Now() - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - bytesRead := 0 - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone -} - -// TestRACKDetectReorder tests that RACK detects packet reordering. -func TestRACKDetectReorder(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - t.Skipf("Skipping this test as reorder detection does not consider DSACK.") - - var n int - const ackNumToVerify = 2 - probeDone := make(chan struct{}) - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - gotSeq := state.Sender.RACKState.FACK - wantSeq := state.Sender.SndNxt - // FACK should be updated to the highest ending sequence number of the - // segment acknowledged most recently. - if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { - t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq) - } - - n++ - if n < ackNumToVerify { - if state.Sender.RACKState.Reord { - t.Fatalf("RACK reorder detected when there is no reordering") - } - return - } - - if state.Sender.RACKState.Reord == false { - t.Fatalf("RACK reorder detection failed") - } - close(probeDone) - }) - setStackSACKPermitted(t, c, true) - setStackRACKPermitted(t, c) - createConnectedWithSACKAndTS(c) - data := make([]byte, ackNumToVerify*maxPayload) - for i := range data { - data[i] = byte(i) - } - - // Write the data. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - bytesRead := 0 - for i := 0; i < ackNumToVerify; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - start := c.IRS.Add(maxPayload + 1) - end := start.Add(maxPayload) - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) - c.SendAck(seq, bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone -} - -func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, enableRACK bool) []byte { - setStackSACKPermitted(t, c, true) - if enableRACK { - setStackRACKPermitted(t, c) - } - createConnectedWithSACKAndTS(c) - - data := make([]byte, numPackets*maxPayload) - for i := range data { - data[i] = byte(i) - } - - // Write the data. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - bytesRead := 0 - for i := 0; i < numPackets; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - // This delay is added to increase RTT as low RTT can cause TLP - // before sending ACK. - time.Sleep(latency) - } - - return data -} - -const ( - validDSACKDetected = 1 - failedToDetectDSACK = 2 - invalidDSACKDetected = 3 -) - -func addDSACKSeenCheckerProbe(t *testing.T, c *context.Context, numACK int, probeDone chan int) { - var n int - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that RACK detects DSACK. - n++ - if n < numACK { - if state.Sender.RACKState.DSACKSeen { - probeDone <- invalidDSACKDetected - } - return - } - - if !state.Sender.RACKState.DSACKSeen { - probeDone <- failedToDetectDSACK - return - } - probeDone <- validDSACKDetected - }) -} - -// TestRACKTLPRecovery tests that RACK sends a tail loss probe (TLP) in the -// case of a tail loss. This simulates a situation where the TLP is able to -// insinuate the SACK holes and sender is able to retransmit the rest. -func TestRACKTLPRecovery(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [6-8] are lost. Send cumulative ACK for [1-5]. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // PTO should fire and send #8 packet as a TLP. - c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize) - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - - // Send the SACK after RTT because RACK RFC states that if the ACK for a - // retransmission arrives before the smoothed RTT then the sender should not - // update RACK state as it could be a spurious inference. - time.Sleep(info.RTT) - - // Okay, let the sender know we got #8 using a SACK block. - eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) - eighthPEnd := eighthPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) - - // The sender should be entering RACK based loss-recovery and sending #6 and - // #7 one after another. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += 2 * maxPayload - c.SendAck(seq, bytesRead) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // One fast retransmit after the SACK. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - // Recovery should be SACK recovery. - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - // Packets 6, 7 and 8 were retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 3}, - // TLP recovery should have been detected. - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 1}, - // No RTOs should have occurred. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKTLPFallbackRTO tests that RACK sends a tail loss probe (TLP) in the -// case of a tail loss. This simulates a situation where either the TLP or its -// ACK is lost. The sender should retransmit when RTO fires. -func TestRACKTLPFallbackRTO(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [6-8] are lost. Send cumulative ACK for [1-5]. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // PTO should fire and send #8 packet as a TLP. - c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize) - - // Either the TLP or the ACK the receiver sent with SACK blocks was lost. - - // Confirm that RTO fires and retransmits packet #6. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // No fast retransmits happened. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - // No SACK recovery happened. - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, - // TLP was unsuccessful. - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestNoTLPRecoveryOnDSACK tests the scenario where the sender speculates a -// tail loss and sends a TLP. Everything is received and acked. The probe -// segment is DSACKed. No fast recovery should be triggered in this case. -func TestNoTLPRecoveryOnDSACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [1-5] are received first. [6-8] took a detour and will take a - // while to arrive. Ack [1-5]. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // The tail loss probe (#8 packet) is received. - c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize) - - // Now that all 8 packets are received + duplicate 8th packet, send ack. - bytesRead += 3 * maxPayload - eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) - eighthPEnd := eighthPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) - - // Wait for RTO and make sure that nothing else is received. - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - if p := c.GetPacketWithTimeout(info.RTO); p != nil { - t.Errorf("received an unexpected packet: %v", p) - } - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // Make sure no recovery was entered. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should not have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - // Only #8 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestNoTLPOnSACK tests the scenario where there is not exactly a tail loss -// due to the presence of multiple SACK holes. In such a scenario, TLP should -// not be sent. -func TestNoTLPOnSACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 8 packets. - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [1-5] and #7 were received. #6 and #8 were dropped. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - seventhStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - seventhEnd := seventhStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhStart, seventhEnd}}) - - // The sender should retransmit #6. If the sender sends a TLP, then #8 will - // received and fail this test. - c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // #6 was retransmitted due to SACK recovery. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should not have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - // Only #6 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKOnePacketTailLoss tests the trivial case of a tail loss of only one -// packet. The probe should itself repairs the loss instead of having to go -// into any recovery. -func TestRACKOnePacketTailLoss(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - // Send 3 packets. - numPackets := 3 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Packets [1-2] are received. #3 is lost. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 2 * maxPayload - c.SendAck(seq, bytesRead) - - // PTO should fire and send #3 packet as a TLP. - c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, tsOptionSize) - bytesRead += maxPayload - c.SendAck(seq, bytesRead) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // #3 was retransmitted as TLP. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, - // RTO should not have fired. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - // Only #3 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKDetectDSACK tests that RACK detects DSACK with duplicate segments. -// See: https://tools.ietf.org/html/rfc2883#section-4.1.1. -func TestRACKDetectDSACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 2 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 8 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Cumulative ACK for [1-5] packets and SACK #8 packet (to prevent TLP). - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload)) - eighthPEnd := eighthPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}}) - - // Expect retransmission of #6 packet after RTO expires. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Send DSACK block for #6 packet indicating both - // initial and retransmitted packet are received and - // packets [1-8] are received. - start := c.IRS.Add(1 + seqnum.Size(bytesRead)) - end := start.Add(maxPayload) - bytesRead += 3 * maxPayload - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Wait for the probe function to finish processing the - // ACK before the test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKWithOutOfOrder tests that RACK detects DSACK with out of -// order segments. -// See: https://tools.ietf.org/html/rfc2883#section-4.1.2. -func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 2 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 10 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP). - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - seventhPEnd := seventhPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}}) - - // Expect retransmission of #6 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Send DSACK block for #6 packet indicating both - // initial and retransmitted packet are received and - // packets [1-7] are received. - start := c.IRS.Add(1 + seqnum.Size(bytesRead)) - end := start.Add(maxPayload) - bytesRead += 2 * maxPayload - // Send DSACK block for #6 along with SACK for out of - // order #9 packet. - start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload) - end1 := start1.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}, {start1, end1}}) - - // Wait for the probe function to finish processing the - // ACK before the test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKWithOutOfOrderDup tests that DSACK is detected on a -// duplicate of out of order packet. -// See: https://tools.ietf.org/html/rfc2883#section-4.1.3 -func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 4 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 10 - sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // ACK [1-5] packets. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - c.SendAck(seq, bytesRead) - - // Send SACK indicating #6 packet is missing and received #7 packet. - offset := seqnum.Size(bytesRead + maxPayload) - start := c.IRS.Add(1 + offset) - end := start.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Send SACK with #6 packet is missing and received [7-8] packets. - end = start.Add(2 * maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Consider #8 packet is duplicated on the network and send DSACK. - dsackStart := c.IRS.Add(1 + offset + maxPayload) - dsackEnd := dsackStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKSingleDup tests DSACK for a single duplicate subsegment. -// See: https://tools.ietf.org/html/rfc2883#section-4.2.1. -func TestRACKDetectDSACKSingleDup(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 4 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 4 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-3] packets and received #4 packet. - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // ACK for retransmitted #2 packet. - bytesRead += maxPayload - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Simulate receving delayed subsegment of #2 packet and delayed #3 packet by - // sending DSACK block for the subsegment. - dsackStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) - c.SendAckWithSACK(seq, numPackets*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKDupWithCumulativeACK tests DSACK for two non-contiguous -// duplicate subsegments covered by the cumulative acknowledgement. -// See: https://tools.ietf.org/html/rfc2883#section-4.2.2. -func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 5 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 6 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-5] packets and received #6 packet. - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(5*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Received delayed #2 packet. - bytesRead += maxPayload - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Received delayed #4 packet. - start1 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) - end1 := start1.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) - - // Simulate receiving retransmitted subsegment for #2 packet and delayed #3 - // packet by sending DSACK block for #2 packet. - dsackStart := c.IRS.Add(1 + seqnum.Size(maxPayload)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) - c.SendAckWithSACK(seq, 4*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKDetectDSACKDup tests two non-contiguous duplicate subsegments not -// covered by the cumulative acknowledgement. -// See: https://tools.ietf.org/html/rfc2883#section-4.2.3. -func TestRACKDetectDSACKDup(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan int) - const ackNumToVerify = 5 - addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) - - numPackets := 7 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-6] packets and SACK #7 packet. - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Received delayed #3 packet. - start1 := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) - end1 := start1.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Consider #2 packet has been dropped and SACK #4 packet. - start2 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) - end2 := start2.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start2, end2}, {start1, end1}, {start, end}}) - - // Simulate receiving retransmitted subsegment for #3 packet and delayed #5 - // packet by sending DSACK block for the subsegment. - dsackStart := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) - dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) - end1 = end1.Add(seqnum.Size(2 * maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start1, end1}}) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - err := <-probeDone - switch err { - case failedToDetectDSACK: - t.Fatalf("RACK DSACK detection failed") - case invalidDSACKDetected: - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } -} - -// TestRACKWithInvalidDSACKBlock tests that DSACK is not detected when DSACK -// is not the first SACK block. -func TestRACKWithInvalidDSACKBlock(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan struct{}) - const ackNumToVerify = 2 - var n int - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that RACK does not detect DSACK when DSACK block is - // not the first SACK block. - n++ - t.Helper() - if state.Sender.RACKState.DSACKSeen { - t.Fatalf("RACK DSACK detected when there is no duplicate SACK") - } - - if n == ackNumToVerify { - close(probeDone) - } - }) - - numPackets := 10 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP). - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - bytesRead := 5 * maxPayload - seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - seventhPEnd := seventhPStart.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}}) - - // Expect retransmission of #6 packet. - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - - // Send DSACK block for #6 packet indicating both - // initial and retransmitted packet are received and - // packets [1-7] are received. - start := c.IRS.Add(1 + seqnum.Size(bytesRead)) - end := start.Add(maxPayload) - bytesRead += 2 * maxPayload - - // Send DSACK block as second block. The first block is a SACK for #9 packet. - start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload) - end1 := start1.Add(maxPayload) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) - - // Wait for the probe function to finish processing the - // ACK before the test completes. - <-probeDone -} - -func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan error) { - var n int - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that RACK detects DSACK. - n++ - if n < numACK { - return - } - - if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.RTTState.SRTT { - probeDone <- fmt.Errorf("got RACKState.ReoWnd: %d, expected it to be greater than 0 and less than %d", state.Sender.RACKState.ReoWnd, state.Sender.RTTState.SRTT) - return - } - - if state.Sender.RACKState.ReoWndIncr != 1 { - probeDone <- fmt.Errorf("got RACKState.ReoWndIncr: %v, want: 1", state.Sender.RACKState.ReoWndIncr) - return - } - - if state.Sender.RACKState.ReoWndPersist > 0 { - probeDone <- fmt.Errorf("got RACKState.ReoWndPersist: %v, want: greater than 0", state.Sender.RACKState.ReoWndPersist) - return - } - probeDone <- nil - }) -} - -func TestRACKCheckReorderWindow(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan error) - const ackNumToVerify = 3 - addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone) - - const numPackets = 7 - sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send ACK for #1 packet. - bytesRead := maxPayload - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, bytesRead) - - // Missing [2-6] packets and SACK #7 packet. - start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - // Received delayed packets [2-6] which indicates there is reordering - // in the connection. - bytesRead += 6 * maxPayload - c.SendAck(seq, bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - if err := <-probeDone; err != nil { - t.Fatalf("unexpected values for RACK variables: %v", err) - } -} - -func TestRACKWithDuplicateACK(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - const numPackets = 4 - data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */) - - // Send three duplicate ACKs to trigger fast recovery. The first - // segment is considered as lost and will be retransmitted after - // receiving the duplicate ACKs. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(1 + seqnum.Size(maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - for i := 0; i < 3; i++ { - c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) - end = end.Add(seqnum.Size(maxPayload)) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestRACKUpdateSackedOut tests the sacked out field is updated when a SACK -// is received. -func TestRACKUpdateSackedOut(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - probeDone := make(chan struct{}) - ackNum := 0 - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint Sender.SackedOut is what we expect. - if state.Sender.SackedOut != 2 && ackNum == 0 { - t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut) - } - - if state.Sender.SackedOut != 0 && ackNum == 1 { - t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut) - } - if ackNum > 0 { - close(probeDone) - } - ackNum++ - }) - - sendAndReceiveWithSACK(t, c, 8, true /* enableRACK */) - - // ACK for [3-5] packets. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload)) - bytesRead := 2 * maxPayload - end := start.Add(seqnum.Size(bytesRead)) - c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) - - bytesRead += 3 * maxPayload - c.SendAck(seq, bytesRead) - - // Wait for the probe function to finish processing the ACK before the - // test completes. - <-probeDone -} - -// TestRACKWithWindowFull tests that RACK honors the receive window size. -func TestRACKWithWindowFull(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - setStackSACKPermitted(t, c, true) - setStackRACKPermitted(t, c) - createConnectedWithSACKAndTS(c) - - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - const numPkts = 10 - data := make([]byte, numPkts*maxPayload) - for i := range data { - data[i] = byte(i) - } - - // Write the data. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - bytesRead := 0 - for i := 0; i < numPkts; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - if i == 0 { - // Send ACK for the first packet to establish RTT. - c.SendAck(seq, maxPayload) - } - } - - // SACK for #10 packet. - start := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload)) - end := start.Add(seqnum.Size(maxPayload)) - c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{start, end}}) - - var info tcpip.TCPInfoOption - if err := c.EP.GetSockOpt(&info); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) - } - // Wait for RTT to trigger recovery. - time.Sleep(info.RTT) - - // Expect retransmission of #2 packet. - c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, tsOptionSize) - - // Send ACK for #2 packet. - c.SendAck(seq, 3*maxPayload) - - // Expect retransmission of #3 packet. - c.ReceiveAndCheckPacketWithOptions(data, 3*maxPayload, maxPayload, tsOptionSize) - - // Send ACK with zero window size. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seq, - AckNum: c.IRS.Add(1 + 4*maxPayload), - RcvWnd: 0, - }) - - // No packet should be received as the receive window size is zero. - c.CheckNoPacket("unexpected packet received after userTimeout has expired") -} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go deleted file mode 100644 index 20c9761f2..000000000 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ /dev/null @@ -1,699 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "fmt" - "log" - "reflect" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/test/testutil" -) - -// createConnectedWithSACKPermittedOption creates and connects c.ep with the -// SACKPermitted option enabled if the stack in the context has the SACK support -// enabled. -func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) -} - -// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS -// option enabled if the stack in the context has SACK and TS enabled. -func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}) -} - -func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { - t.Helper() - opt := tcpip.TCPSACKEnabled(enable) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } -} - -// TestSackPermittedConnect establishes a connection with the SACK option -// enabled. -func TestSackPermittedConnect(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - setStackSACKPermitted(t, c, sackEnabled) - rep := createConnectedWithSACKPermittedOption(c) - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // Restore the saved sequence number so that the - // VerifyXXX calls use the right sequence number for - // checking ACK numbers. - rep.NextSeqNum = savedSeqNum - if sackEnabled { - rep.VerifyACKHasSACK(sackBlocks) - } else { - rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() - }) - } -} - -// TestSackDisabledConnect establishes a connection with the SACK option -// disabled and verifies that no SACKs are sent for out of order segments. -func TestSackDisabledConnect(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.CreateConnectedWithOptions(header.TCPSynOptions{}) - - data := []byte{1, 2, 3} - - rep.SendPacket(data, nil) - savedSeqNum := rep.NextSeqNum - rep.VerifyACKNoSACK() - - // Make an out of order packet and send it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) - - // The ACK should contain the older sequence number and - // no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative ACK for all 9 - // bytes sent and no SACK blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned in the ACK. - rep.VerifyACKNoSACK() - }) - } -} - -// TestSackPermittedAccept accepts and establishes a connection with the -// SACKPermitted option enabled if the connection request specifies the -// SACKPermitted option. In case of SYN cookies SACK should be disabled as we -// don't encode the SACK information in the cookie. -func TestSackPermittedAccept(t *testing.T) { - type testCase struct { - cookieEnabled bool - sackPermitted bool - wndScale int - wndSize uint16 - } - - testCases := []testCase{ - // When cookie is used window scaling is disabled. - {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - if tc.cookieEnabled { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - sackBlocks := []header.SACKBlock{ - {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, - } - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number. - rep.NextSeqNum = savedSeqNum - if sackEnabled && tc.sackPermitted { - rep.VerifyACKHasSACK(sackBlocks) - } else { - rep.VerifyACKNoSACK() - } - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() - }) - } - }) - } -} - -// TestSackDisabledAccept accepts and establishes a connection with -// the SACKPermitted option disabled and verifies that no SACKs are -// sent for out of order packets. -func TestSackDisabledAccept(t *testing.T) { - type testCase struct { - cookieEnabled bool - wndScale int - wndSize uint16 - } - - testCases := []testCase{ - // When cookie is used window scaling is disabled. - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { - for _, sackEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - if tc.cookieEnabled { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - - setStackSACKPermitted(t, c, sackEnabled) - - rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Now verify no SACK blocks are - // received when sack is disabled. - data := []byte{1, 2, 3} - rep.SendPacket(data, nil) - rep.VerifyACKNoSACK() - savedSeqNum := rep.NextSeqNum - - // Make an out of order packet and send - // it. - rep.NextSeqNum += 3 - rep.SendPacket(data, nil) - - // The ACK should contain the older - // sequence number and no SACK blocks. - rep.NextSeqNum = savedSeqNum - rep.VerifyACKNoSACK() - - // Send the missing segment. - rep.SendPacket(data, nil) - // The ACK should contain the cumulative - // ACK for all 9 bytes sent and no SACK - // blocks. - rep.NextSeqNum += 3 - // Check that no SACK block is returned - // in the ACK. - rep.VerifyACKNoSACK() - }) - } - }) - } -} - -func TestUpdateSACKBlocks(t *testing.T) { - testCases := []struct { - segStart seqnum.Value - segEnd seqnum.Value - rcvNxt seqnum.Value - sackBlocks []header.SACKBlock - updated []header.SACKBlock - }{ - // Trivial cases where current SACK block list is empty and we - // have an out of order delivery. - {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}}, - {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}}, - {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}}, - - // Cases where current SACK block list is not empty and we have - // an out of order delivery. Tests that the updated SACK block - // list has the first block as the one that contains the new - // SACK block representing the segment that was just delivered. - {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}}, - {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}}, - {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}}, - - // Ensure that we only retain header.MaxSACKBlocks and drop the - // oldest one if adding a new block exceeds - // header.MaxSACKBlocks. - {24, 30, 9, - []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}}, - []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}}, - - // Cases where segment extends an existing SACK block. - {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, - {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}}, - {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}}, - {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}}, - {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, - {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, - {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}}, - {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}}, - {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}}, - - // Cases where segment contains rcvNxt. - {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}}, - } - - for _, tc := range testCases { - var sack tcp.SACKInfo - copy(sack.Blocks[:], tc.sackBlocks) - sack.NumBlocks = len(tc.sackBlocks) - tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt) - if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) { - t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want) - } - - } -} - -func TestTrimSackBlockList(t *testing.T) { - testCases := []struct { - rcvNxt seqnum.Value - sackBlocks []header.SACKBlock - trimmed []header.SACKBlock - }{ - // Simple cases where we trim whole entries. - {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}}, - {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}}, - {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}}, - {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, - // Cases where we need to update a block. - {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}}, - {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}}, - {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}}, - {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, - } - for _, tc := range testCases { - var sack tcp.SACKInfo - copy(sack.Blocks[:], tc.sackBlocks) - sack.NumBlocks = len(tc.sackBlocks) - tcp.TrimSACKBlockList(&sack, tc.rcvNxt) - if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) { - t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want) - } - } -} - -func TestSACKRecovery(t *testing.T) { - const maxPayload = 10 - // See: tcp.makeOptions for why tsOptionSize is set to 12 here. - const tsOptionSize = 12 - // Enabling SACK means the payload size is reduced to account - // for the extra space required for the TCP options. - // - // We increase the MTU by 40 bytes to account for SACK and Timestamp - // options. - const maxTCPOptionSize = 40 - - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) - defer c.Cleanup() - - c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) { - // We use log.Printf instead of t.Logf here because this probe - // can fire even when the test function has finished. This is - // because closing the endpoint in cleanup() does not mean the - // actual worker loop terminates immediately as it still has to - // do a full TCP shutdown. But this test can finish running - // before the shutdown is done. Using t.Logf in such a case - // causes the test to panic due to logging after test finished. - log.Printf("state: %+v\n", s) - }) - setStackSACKPermitted(t, c, true) - createConnectedWithSACKAndTS(c) - - const iterations = 3 - data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) - for i := range data { - data[i] = byte(i) - } - - // Write all the data in one shot. Packets will only be written at the - // MTU size though. - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Do slow start for a few iterations. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - expected := tcp.InitialCwnd - bytesRead := 0 - for i := 0; i < iterations; i++ { - expected = tcp.InitialCwnd << uint(i) - if i > 0 { - // Acknowledge all the data received so far if not on - // first iteration. - c.SendAck(seq, bytesRead) - } - - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) - } - - // Send 3 duplicate acks. This should force an immediate retransmit of - // the pending packet and put the sender into fast recovery. - rtxOffset := bytesRead - maxPayload*expected - start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end := start.Add(10) - for i := 0; i < 3; i++ { - c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } - - // Receive the retransmitted packet. - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause - // window inflation and sending of packets is completely handled by the - // SACK Recovery algorithm. We should see no packets being released, as - // the cwnd at this point after entering recovery should be half of the - // outstanding number of packets in flight. - for i := 0; i < 7; i++ { - c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - end = end.Add(10) - } - - recover := bytesRead - - // Ensure no new packets arrive. - c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", - 50*time.Millisecond) - - // Acknowledge half of the pending data. This along with the 10 sacked - // segments above should reduce the outstanding below the current - // congestion window allowing the sender to transmit data. - rtxOffset = bytesRead - expected*maxPayload/2 - - // Now send a partial ACK w/ a SACK block that indicates that the next 3 - // segments are lost and we have received 6 segments after the lost - // segments. This should cause the sender to immediately transmit all 3 - // segments in response to this ACK unlike in FastRecovery where only 1 - // segment is retransmitted per ACK. - start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) - end = start.Add(60) - c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}}) - - // At this point, we acked expected/2 packets and we SACKED 6 packets and - // 3 segments were considered lost due to the SACK block we sent. - // - // So total packets outstanding can be calculated as follows after 7 - // iterations of slow start -> 10/20/40/80/160/320/640. So expected - // should be 640 at start, then we went to recover at which point the - // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the - // network). - // Outstanding at this point after acking half the window - // (320 packets) will be: - // outstanding = 640-320-6(due to SACK block)-3 = 311 - // - // The last 3 is due to the fact that the first 3 packets after - // rtxOffset will be considered lost due to the SACK blocks sent. - // Receive the retransmit due to partial ack. - - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) - // Receive the 2 extra packets that should have been retransmitted as - // those should be considered lost and immediately retransmitted based - // on the SACK information in the previous ACK sent above. - for i := 0; i < 2; i++ { - c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize) - } - - // Now we should get 9 more new unsent packets as the cwnd is 323 and - // outstanding is 311. - for i := 0; i < 9; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - - metricPollFn = func() error { - // In SACK recovery only the first segment is fast retransmitted when - // entering recovery. - if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { - return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want) - } - - if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { - return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) - } - - if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { - return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want) - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) - - // Acknowledge all pending data to recover point. - c.SendAck(seq, recover) - - // At this point, the cwnd should reset to expected/2 and there are 9 - // packets outstanding. - // - // Now in the first iteration since there are 9 packets outstanding. - // We would expect to get expected/2 - 9 packets. But subsequent - // iterations will send us expected/2 + 1 (per iteration). - expected = expected/2 - 9 - for i := 0; i < iterations; i++ { - // Read all packets expected on this iteration. Don't - // acknowledge any of them just yet, so that we can measure the - // congestion window. - for j := 0; j < expected; j++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - bytesRead += maxPayload - } - // Check we don't receive any more packets on this iteration. - // The timeout can't be too high or we'll trigger a timeout. - c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond) - - // Acknowledge all the data received so far. - c.SendAck(seq, bytesRead) - - // In cogestion avoidance, the packets trains increase by 1 in - // each iteration. - if i == 0 { - // After the first iteration we expect to get the full - // congestion window worth of packets in every - // iteration. - expected += 9 - } - expected++ - } -} - -// TestRecoveryEntry tests the following two properties of entering recovery: -// - Fast SACK recovery is entered when SND.UNA is considered lost by the SACK -// scoreboard but dupack count is still below threshold. -// - Only enter recovery when at least one more byte of data beyond the highest -// byte that was outstanding when fast retransmit was last entered is acked. -func TestRecoveryEntry(t *testing.T) { - c := context.New(t, uint32(mtu)) - defer c.Cleanup() - - numPackets := 5 - data := sendAndReceiveWithSACK(t, c, numPackets, false /* enableRACK */) - - // Ack #1 packet. - seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendAck(seq, maxPayload) - - // Now SACK #3, #4 and #5 packets. This will simulate a situation where - // SND.UNA should be considered lost and the sender should enter fast recovery - // (even though dupack count is still below threshold). - p3Start := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) - p3End := p3Start.Add(maxPayload) - p4Start := p3End - p4End := p4Start.Add(maxPayload) - p5Start := p4End - p5End := p5Start.Add(maxPayload) - c.SendAckWithSACK(seq, maxPayload, []header.SACKBlock{{p3Start, p3End}, {p4Start, p4End}, {p5Start, p5End}}) - - // Expect #2 to be retransmitted. - c.ReceiveAndCheckPacketWithOptions(data, maxPayload, maxPayload, tsOptionSize) - - metricPollFn := func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // SACK recovery must have happened. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - // #2 was retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 1}, - // No RTOs should have fired yet. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } - - // Send 4 more packets. - var r bytes.Reader - data = append(data, data...) - r.Reset(data[5*maxPayload : 9*maxPayload]) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - var sackBlocks []header.SACKBlock - bytesRead := numPackets * maxPayload - for i := 0; i < 4; i++ { - c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) - if i > 0 { - pStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) - sackBlocks = append(sackBlocks, header.SACKBlock{pStart, pStart.Add(maxPayload)}) - c.SendAckWithSACK(seq, 5*maxPayload, sackBlocks) - } - bytesRead += maxPayload - } - - // #6 should be retransmitted after RTO. The sender should NOT enter fast - // recovery because the highest byte that was outstanding when fast recovery - // was last entered is #5 packet's end. And the sender requires at least one - // more byte beyond that (#6 packet start) to be acked to enter recovery. - c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize) - c.SendAck(seq, 9*maxPayload) - - metricPollFn = func() error { - tcpStats := c.Stack().Stats().TCP - stats := []struct { - stat *tcpip.StatCounter - name string - want uint64 - }{ - // Only 1 SACK recovery must have happened. - {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, - // #2 and #6 were retransmitted. - {tcpStats.Retransmits, "stats.TCP.Retransmits", 2}, - // RTO should have fired once. - {tcpStats.Timeouts, "stats.TCP.Timeouts", 1}, - } - for _, s := range stats { - if got, want := s.stat.Value(), s.want; got != want { - return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) - } - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go new file mode 100644 index 000000000..a14cff27e --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go @@ -0,0 +1,221 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type segmentElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type segmentList struct { + head *segment + tail *segment +} + +// Reset resets list l to the empty state. +func (l *segmentList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *segmentList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *segmentList) Front() *segment { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *segmentList) Back() *segment { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *segmentList) Len() (count int) { + for e := l.Front(); e != nil; e = (segmentElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *segmentList) PushFront(e *segment) { + linker := segmentElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + segmentElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *segmentList) PushBack(e *segment) { + linker := segmentElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *segmentList) PushBackList(m *segmentList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head) + segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *segmentList) InsertAfter(b, e *segment) { + bLinker := segmentElementMapper{}.linkerFor(b) + eLinker := segmentElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + segmentElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *segmentList) InsertBefore(a, e *segment) { + aLinker := segmentElementMapper{}.linkerFor(a) + eLinker := segmentElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + segmentElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *segmentList) Remove(e *segment) { + linker := segmentElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + segmentElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + segmentElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type segmentEntry struct { + next *segment + prev *segment +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *segmentEntry) Next() *segment { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *segmentEntry) Prev() *segment { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *segmentEntry) SetNext(elem *segment) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *segmentEntry) SetPrev(elem *segment) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go new file mode 100644 index 000000000..139b1c796 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -0,0 +1,934 @@ +// automatically generated by stateify. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (c *cubicState) StateTypeName() string { + return "pkg/tcpip/transport/tcp.cubicState" +} + +func (c *cubicState) StateFields() []string { + return []string{ + "TCPCubicState", + "numCongestionEvents", + "s", + } +} + +func (c *cubicState) beforeSave() {} + +// +checklocksignore +func (c *cubicState) StateSave(stateSinkObject state.Sink) { + c.beforeSave() + stateSinkObject.Save(0, &c.TCPCubicState) + stateSinkObject.Save(1, &c.numCongestionEvents) + stateSinkObject.Save(2, &c.s) +} + +func (c *cubicState) afterLoad() {} + +// +checklocksignore +func (c *cubicState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &c.TCPCubicState) + stateSourceObject.Load(1, &c.numCongestionEvents) + stateSourceObject.Load(2, &c.s) +} + +func (s *SACKInfo) StateTypeName() string { + return "pkg/tcpip/transport/tcp.SACKInfo" +} + +func (s *SACKInfo) StateFields() []string { + return []string{ + "Blocks", + "NumBlocks", + } +} + +func (s *SACKInfo) beforeSave() {} + +// +checklocksignore +func (s *SACKInfo) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.Blocks) + stateSinkObject.Save(1, &s.NumBlocks) +} + +func (s *SACKInfo) afterLoad() {} + +// +checklocksignore +func (s *SACKInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.Blocks) + stateSourceObject.Load(1, &s.NumBlocks) +} + +func (s *sndQueueInfo) StateTypeName() string { + return "pkg/tcpip/transport/tcp.sndQueueInfo" +} + +func (s *sndQueueInfo) StateFields() []string { + return []string{ + "TCPSndBufState", + "sndQueue", + } +} + +func (s *sndQueueInfo) beforeSave() {} + +// +checklocksignore +func (s *sndQueueInfo) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.TCPSndBufState) + stateSinkObject.Save(1, &s.sndQueue) +} + +func (s *sndQueueInfo) afterLoad() {} + +// +checklocksignore +func (s *sndQueueInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.TCPSndBufState) + stateSourceObject.LoadWait(1, &s.sndQueue) +} + +func (r *rcvQueueInfo) StateTypeName() string { + return "pkg/tcpip/transport/tcp.rcvQueueInfo" +} + +func (r *rcvQueueInfo) StateFields() []string { + return []string{ + "TCPRcvBufState", + "rcvQueue", + } +} + +func (r *rcvQueueInfo) beforeSave() {} + +// +checklocksignore +func (r *rcvQueueInfo) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.TCPRcvBufState) + stateSinkObject.Save(1, &r.rcvQueue) +} + +func (r *rcvQueueInfo) afterLoad() {} + +// +checklocksignore +func (r *rcvQueueInfo) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.TCPRcvBufState) + stateSourceObject.LoadWait(1, &r.rcvQueue) +} + +func (a *accepted) StateTypeName() string { + return "pkg/tcpip/transport/tcp.accepted" +} + +func (a *accepted) StateFields() []string { + return []string{ + "endpoints", + "cap", + } +} + +func (a *accepted) beforeSave() {} + +// +checklocksignore +func (a *accepted) StateSave(stateSinkObject state.Sink) { + a.beforeSave() + var endpointsValue []*endpoint = a.saveEndpoints() + stateSinkObject.SaveValue(0, endpointsValue) + stateSinkObject.Save(1, &a.cap) +} + +func (a *accepted) afterLoad() {} + +// +checklocksignore +func (a *accepted) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(1, &a.cap) + stateSourceObject.LoadValue(0, new([]*endpoint), func(y interface{}) { a.loadEndpoints(y.([]*endpoint)) }) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/tcp.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "TCPEndpointStateInner", + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "waiterQueue", + "uniqueID", + "hardError", + "lastError", + "rcvQueueInfo", + "rcvMemUsed", + "ownedByUser", + "state", + "boundNICID", + "ttl", + "isConnectNotified", + "portFlags", + "boundBindToDevice", + "boundPortFlags", + "boundDest", + "effectiveNetProtos", + "workerRunning", + "workerCleanup", + "recentTSTime", + "shutdownFlags", + "tcpRecovery", + "sack", + "delay", + "scoreboard", + "segmentQueue", + "synRcvdCount", + "userMSS", + "maxSynRetries", + "windowClamp", + "sndQueueInfo", + "cc", + "keepalive", + "userTimeout", + "deferAccept", + "accepted", + "rcv", + "snd", + "connectingAddress", + "amss", + "sendTOS", + "gso", + "tcpLingerTimeout", + "closed", + "txHash", + "owner", + "ops", + "lastOutOfWindowAckTime", + } +} + +// +checklocksignore +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + var stateValue EndpointState = e.saveState() + stateSinkObject.SaveValue(10, stateValue) + var recentTSTimeValue unixTime = e.saveRecentTSTime() + stateSinkObject.SaveValue(21, recentTSTimeValue) + var lastOutOfWindowAckTimeValue unixTime = e.saveLastOutOfWindowAckTime() + stateSinkObject.SaveValue(49, lastOutOfWindowAckTimeValue) + stateSinkObject.Save(0, &e.TCPEndpointStateInner) + stateSinkObject.Save(1, &e.TransportEndpointInfo) + stateSinkObject.Save(2, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(3, &e.waiterQueue) + stateSinkObject.Save(4, &e.uniqueID) + stateSinkObject.Save(5, &e.hardError) + stateSinkObject.Save(6, &e.lastError) + stateSinkObject.Save(7, &e.rcvQueueInfo) + stateSinkObject.Save(8, &e.rcvMemUsed) + stateSinkObject.Save(9, &e.ownedByUser) + stateSinkObject.Save(11, &e.boundNICID) + stateSinkObject.Save(12, &e.ttl) + stateSinkObject.Save(13, &e.isConnectNotified) + stateSinkObject.Save(14, &e.portFlags) + stateSinkObject.Save(15, &e.boundBindToDevice) + stateSinkObject.Save(16, &e.boundPortFlags) + stateSinkObject.Save(17, &e.boundDest) + stateSinkObject.Save(18, &e.effectiveNetProtos) + stateSinkObject.Save(19, &e.workerRunning) + stateSinkObject.Save(20, &e.workerCleanup) + stateSinkObject.Save(22, &e.shutdownFlags) + stateSinkObject.Save(23, &e.tcpRecovery) + stateSinkObject.Save(24, &e.sack) + stateSinkObject.Save(25, &e.delay) + stateSinkObject.Save(26, &e.scoreboard) + stateSinkObject.Save(27, &e.segmentQueue) + stateSinkObject.Save(28, &e.synRcvdCount) + stateSinkObject.Save(29, &e.userMSS) + stateSinkObject.Save(30, &e.maxSynRetries) + stateSinkObject.Save(31, &e.windowClamp) + stateSinkObject.Save(32, &e.sndQueueInfo) + stateSinkObject.Save(33, &e.cc) + stateSinkObject.Save(34, &e.keepalive) + stateSinkObject.Save(35, &e.userTimeout) + stateSinkObject.Save(36, &e.deferAccept) + stateSinkObject.Save(37, &e.accepted) + stateSinkObject.Save(38, &e.rcv) + stateSinkObject.Save(39, &e.snd) + stateSinkObject.Save(40, &e.connectingAddress) + stateSinkObject.Save(41, &e.amss) + stateSinkObject.Save(42, &e.sendTOS) + stateSinkObject.Save(43, &e.gso) + stateSinkObject.Save(44, &e.tcpLingerTimeout) + stateSinkObject.Save(45, &e.closed) + stateSinkObject.Save(46, &e.txHash) + stateSinkObject.Save(47, &e.owner) + stateSinkObject.Save(48, &e.ops) +} + +// +checklocksignore +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.TCPEndpointStateInner) + stateSourceObject.Load(1, &e.TransportEndpointInfo) + stateSourceObject.Load(2, &e.DefaultSocketOptionsHandler) + stateSourceObject.LoadWait(3, &e.waiterQueue) + stateSourceObject.Load(4, &e.uniqueID) + stateSourceObject.Load(5, &e.hardError) + stateSourceObject.Load(6, &e.lastError) + stateSourceObject.Load(7, &e.rcvQueueInfo) + stateSourceObject.Load(8, &e.rcvMemUsed) + stateSourceObject.Load(9, &e.ownedByUser) + stateSourceObject.Load(11, &e.boundNICID) + stateSourceObject.Load(12, &e.ttl) + stateSourceObject.Load(13, &e.isConnectNotified) + stateSourceObject.Load(14, &e.portFlags) + stateSourceObject.Load(15, &e.boundBindToDevice) + stateSourceObject.Load(16, &e.boundPortFlags) + stateSourceObject.Load(17, &e.boundDest) + stateSourceObject.Load(18, &e.effectiveNetProtos) + stateSourceObject.Load(19, &e.workerRunning) + stateSourceObject.Load(20, &e.workerCleanup) + stateSourceObject.Load(22, &e.shutdownFlags) + stateSourceObject.Load(23, &e.tcpRecovery) + stateSourceObject.Load(24, &e.sack) + stateSourceObject.Load(25, &e.delay) + stateSourceObject.Load(26, &e.scoreboard) + stateSourceObject.LoadWait(27, &e.segmentQueue) + stateSourceObject.Load(28, &e.synRcvdCount) + stateSourceObject.Load(29, &e.userMSS) + stateSourceObject.Load(30, &e.maxSynRetries) + stateSourceObject.Load(31, &e.windowClamp) + stateSourceObject.Load(32, &e.sndQueueInfo) + stateSourceObject.Load(33, &e.cc) + stateSourceObject.Load(34, &e.keepalive) + stateSourceObject.Load(35, &e.userTimeout) + stateSourceObject.Load(36, &e.deferAccept) + stateSourceObject.Load(37, &e.accepted) + stateSourceObject.LoadWait(38, &e.rcv) + stateSourceObject.LoadWait(39, &e.snd) + stateSourceObject.Load(40, &e.connectingAddress) + stateSourceObject.Load(41, &e.amss) + stateSourceObject.Load(42, &e.sendTOS) + stateSourceObject.Load(43, &e.gso) + stateSourceObject.Load(44, &e.tcpLingerTimeout) + stateSourceObject.Load(45, &e.closed) + stateSourceObject.Load(46, &e.txHash) + stateSourceObject.Load(47, &e.owner) + stateSourceObject.Load(48, &e.ops) + stateSourceObject.LoadValue(10, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) }) + stateSourceObject.LoadValue(21, new(unixTime), func(y interface{}) { e.loadRecentTSTime(y.(unixTime)) }) + stateSourceObject.LoadValue(49, new(unixTime), func(y interface{}) { e.loadLastOutOfWindowAckTime(y.(unixTime)) }) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (k *keepalive) StateTypeName() string { + return "pkg/tcpip/transport/tcp.keepalive" +} + +func (k *keepalive) StateFields() []string { + return []string{ + "idle", + "interval", + "count", + "unacked", + } +} + +func (k *keepalive) beforeSave() {} + +// +checklocksignore +func (k *keepalive) StateSave(stateSinkObject state.Sink) { + k.beforeSave() + stateSinkObject.Save(0, &k.idle) + stateSinkObject.Save(1, &k.interval) + stateSinkObject.Save(2, &k.count) + stateSinkObject.Save(3, &k.unacked) +} + +func (k *keepalive) afterLoad() {} + +// +checklocksignore +func (k *keepalive) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &k.idle) + stateSourceObject.Load(1, &k.interval) + stateSourceObject.Load(2, &k.count) + stateSourceObject.Load(3, &k.unacked) +} + +func (rc *rackControl) StateTypeName() string { + return "pkg/tcpip/transport/tcp.rackControl" +} + +func (rc *rackControl) StateFields() []string { + return []string{ + "TCPRACKState", + "exitedRecovery", + "minRTT", + "tlpRxtOut", + "tlpHighRxt", + "snd", + } +} + +func (rc *rackControl) beforeSave() {} + +// +checklocksignore +func (rc *rackControl) StateSave(stateSinkObject state.Sink) { + rc.beforeSave() + stateSinkObject.Save(0, &rc.TCPRACKState) + stateSinkObject.Save(1, &rc.exitedRecovery) + stateSinkObject.Save(2, &rc.minRTT) + stateSinkObject.Save(3, &rc.tlpRxtOut) + stateSinkObject.Save(4, &rc.tlpHighRxt) + stateSinkObject.Save(5, &rc.snd) +} + +func (rc *rackControl) afterLoad() {} + +// +checklocksignore +func (rc *rackControl) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &rc.TCPRACKState) + stateSourceObject.Load(1, &rc.exitedRecovery) + stateSourceObject.Load(2, &rc.minRTT) + stateSourceObject.Load(3, &rc.tlpRxtOut) + stateSourceObject.Load(4, &rc.tlpHighRxt) + stateSourceObject.Load(5, &rc.snd) +} + +func (r *receiver) StateTypeName() string { + return "pkg/tcpip/transport/tcp.receiver" +} + +func (r *receiver) StateFields() []string { + return []string{ + "TCPReceiverState", + "ep", + "rcvWnd", + "rcvWUP", + "prevBufUsed", + "closed", + "pendingRcvdSegments", + "lastRcvdAckTime", + } +} + +func (r *receiver) beforeSave() {} + +// +checklocksignore +func (r *receiver) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + var lastRcvdAckTimeValue unixTime = r.saveLastRcvdAckTime() + stateSinkObject.SaveValue(7, lastRcvdAckTimeValue) + stateSinkObject.Save(0, &r.TCPReceiverState) + stateSinkObject.Save(1, &r.ep) + stateSinkObject.Save(2, &r.rcvWnd) + stateSinkObject.Save(3, &r.rcvWUP) + stateSinkObject.Save(4, &r.prevBufUsed) + stateSinkObject.Save(5, &r.closed) + stateSinkObject.Save(6, &r.pendingRcvdSegments) +} + +func (r *receiver) afterLoad() {} + +// +checklocksignore +func (r *receiver) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.TCPReceiverState) + stateSourceObject.Load(1, &r.ep) + stateSourceObject.Load(2, &r.rcvWnd) + stateSourceObject.Load(3, &r.rcvWUP) + stateSourceObject.Load(4, &r.prevBufUsed) + stateSourceObject.Load(5, &r.closed) + stateSourceObject.Load(6, &r.pendingRcvdSegments) + stateSourceObject.LoadValue(7, new(unixTime), func(y interface{}) { r.loadLastRcvdAckTime(y.(unixTime)) }) +} + +func (r *renoState) StateTypeName() string { + return "pkg/tcpip/transport/tcp.renoState" +} + +func (r *renoState) StateFields() []string { + return []string{ + "s", + } +} + +func (r *renoState) beforeSave() {} + +// +checklocksignore +func (r *renoState) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.s) +} + +func (r *renoState) afterLoad() {} + +// +checklocksignore +func (r *renoState) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.s) +} + +func (rr *renoRecovery) StateTypeName() string { + return "pkg/tcpip/transport/tcp.renoRecovery" +} + +func (rr *renoRecovery) StateFields() []string { + return []string{ + "s", + } +} + +func (rr *renoRecovery) beforeSave() {} + +// +checklocksignore +func (rr *renoRecovery) StateSave(stateSinkObject state.Sink) { + rr.beforeSave() + stateSinkObject.Save(0, &rr.s) +} + +func (rr *renoRecovery) afterLoad() {} + +// +checklocksignore +func (rr *renoRecovery) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &rr.s) +} + +func (sr *sackRecovery) StateTypeName() string { + return "pkg/tcpip/transport/tcp.sackRecovery" +} + +func (sr *sackRecovery) StateFields() []string { + return []string{ + "s", + } +} + +func (sr *sackRecovery) beforeSave() {} + +// +checklocksignore +func (sr *sackRecovery) StateSave(stateSinkObject state.Sink) { + sr.beforeSave() + stateSinkObject.Save(0, &sr.s) +} + +func (sr *sackRecovery) afterLoad() {} + +// +checklocksignore +func (sr *sackRecovery) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &sr.s) +} + +func (s *SACKScoreboard) StateTypeName() string { + return "pkg/tcpip/transport/tcp.SACKScoreboard" +} + +func (s *SACKScoreboard) StateFields() []string { + return []string{ + "smss", + "maxSACKED", + } +} + +func (s *SACKScoreboard) beforeSave() {} + +// +checklocksignore +func (s *SACKScoreboard) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + stateSinkObject.Save(0, &s.smss) + stateSinkObject.Save(1, &s.maxSACKED) +} + +func (s *SACKScoreboard) afterLoad() {} + +// +checklocksignore +func (s *SACKScoreboard) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.smss) + stateSourceObject.Load(1, &s.maxSACKED) +} + +func (s *segment) StateTypeName() string { + return "pkg/tcpip/transport/tcp.segment" +} + +func (s *segment) StateFields() []string { + return []string{ + "segmentEntry", + "refCnt", + "ep", + "qFlags", + "srcAddr", + "dstAddr", + "netProto", + "nicID", + "data", + "hdr", + "sequenceNumber", + "ackNumber", + "flags", + "window", + "csum", + "csumValid", + "parsedOptions", + "options", + "hasNewSACKInfo", + "rcvdTime", + "xmitTime", + "xmitCount", + "acked", + "dataMemSize", + "lost", + } +} + +func (s *segment) beforeSave() {} + +// +checklocksignore +func (s *segment) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var dataValue buffer.VectorisedView = s.saveData() + stateSinkObject.SaveValue(8, dataValue) + var optionsValue []byte = s.saveOptions() + stateSinkObject.SaveValue(17, optionsValue) + var rcvdTimeValue unixTime = s.saveRcvdTime() + stateSinkObject.SaveValue(19, rcvdTimeValue) + var xmitTimeValue unixTime = s.saveXmitTime() + stateSinkObject.SaveValue(20, xmitTimeValue) + stateSinkObject.Save(0, &s.segmentEntry) + stateSinkObject.Save(1, &s.refCnt) + stateSinkObject.Save(2, &s.ep) + stateSinkObject.Save(3, &s.qFlags) + stateSinkObject.Save(4, &s.srcAddr) + stateSinkObject.Save(5, &s.dstAddr) + stateSinkObject.Save(6, &s.netProto) + stateSinkObject.Save(7, &s.nicID) + stateSinkObject.Save(9, &s.hdr) + stateSinkObject.Save(10, &s.sequenceNumber) + stateSinkObject.Save(11, &s.ackNumber) + stateSinkObject.Save(12, &s.flags) + stateSinkObject.Save(13, &s.window) + stateSinkObject.Save(14, &s.csum) + stateSinkObject.Save(15, &s.csumValid) + stateSinkObject.Save(16, &s.parsedOptions) + stateSinkObject.Save(18, &s.hasNewSACKInfo) + stateSinkObject.Save(21, &s.xmitCount) + stateSinkObject.Save(22, &s.acked) + stateSinkObject.Save(23, &s.dataMemSize) + stateSinkObject.Save(24, &s.lost) +} + +func (s *segment) afterLoad() {} + +// +checklocksignore +func (s *segment) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.segmentEntry) + stateSourceObject.Load(1, &s.refCnt) + stateSourceObject.Load(2, &s.ep) + stateSourceObject.Load(3, &s.qFlags) + stateSourceObject.Load(4, &s.srcAddr) + stateSourceObject.Load(5, &s.dstAddr) + stateSourceObject.Load(6, &s.netProto) + stateSourceObject.Load(7, &s.nicID) + stateSourceObject.Load(9, &s.hdr) + stateSourceObject.Load(10, &s.sequenceNumber) + stateSourceObject.Load(11, &s.ackNumber) + stateSourceObject.Load(12, &s.flags) + stateSourceObject.Load(13, &s.window) + stateSourceObject.Load(14, &s.csum) + stateSourceObject.Load(15, &s.csumValid) + stateSourceObject.Load(16, &s.parsedOptions) + stateSourceObject.Load(18, &s.hasNewSACKInfo) + stateSourceObject.Load(21, &s.xmitCount) + stateSourceObject.Load(22, &s.acked) + stateSourceObject.Load(23, &s.dataMemSize) + stateSourceObject.Load(24, &s.lost) + stateSourceObject.LoadValue(8, new(buffer.VectorisedView), func(y interface{}) { s.loadData(y.(buffer.VectorisedView)) }) + stateSourceObject.LoadValue(17, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) }) + stateSourceObject.LoadValue(19, new(unixTime), func(y interface{}) { s.loadRcvdTime(y.(unixTime)) }) + stateSourceObject.LoadValue(20, new(unixTime), func(y interface{}) { s.loadXmitTime(y.(unixTime)) }) +} + +func (q *segmentQueue) StateTypeName() string { + return "pkg/tcpip/transport/tcp.segmentQueue" +} + +func (q *segmentQueue) StateFields() []string { + return []string{ + "list", + "ep", + "frozen", + } +} + +func (q *segmentQueue) beforeSave() {} + +// +checklocksignore +func (q *segmentQueue) StateSave(stateSinkObject state.Sink) { + q.beforeSave() + stateSinkObject.Save(0, &q.list) + stateSinkObject.Save(1, &q.ep) + stateSinkObject.Save(2, &q.frozen) +} + +func (q *segmentQueue) afterLoad() {} + +// +checklocksignore +func (q *segmentQueue) StateLoad(stateSourceObject state.Source) { + stateSourceObject.LoadWait(0, &q.list) + stateSourceObject.Load(1, &q.ep) + stateSourceObject.Load(2, &q.frozen) +} + +func (s *sender) StateTypeName() string { + return "pkg/tcpip/transport/tcp.sender" +} + +func (s *sender) StateFields() []string { + return []string{ + "TCPSenderState", + "ep", + "lr", + "firstRetransmittedSegXmitTime", + "writeNext", + "writeList", + "rtt", + "minRTO", + "maxRTO", + "maxRetries", + "gso", + "state", + "cc", + "rc", + } +} + +func (s *sender) beforeSave() {} + +// +checklocksignore +func (s *sender) StateSave(stateSinkObject state.Sink) { + s.beforeSave() + var firstRetransmittedSegXmitTimeValue unixTime = s.saveFirstRetransmittedSegXmitTime() + stateSinkObject.SaveValue(3, firstRetransmittedSegXmitTimeValue) + stateSinkObject.Save(0, &s.TCPSenderState) + stateSinkObject.Save(1, &s.ep) + stateSinkObject.Save(2, &s.lr) + stateSinkObject.Save(4, &s.writeNext) + stateSinkObject.Save(5, &s.writeList) + stateSinkObject.Save(6, &s.rtt) + stateSinkObject.Save(7, &s.minRTO) + stateSinkObject.Save(8, &s.maxRTO) + stateSinkObject.Save(9, &s.maxRetries) + stateSinkObject.Save(10, &s.gso) + stateSinkObject.Save(11, &s.state) + stateSinkObject.Save(12, &s.cc) + stateSinkObject.Save(13, &s.rc) +} + +// +checklocksignore +func (s *sender) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &s.TCPSenderState) + stateSourceObject.Load(1, &s.ep) + stateSourceObject.Load(2, &s.lr) + stateSourceObject.Load(4, &s.writeNext) + stateSourceObject.Load(5, &s.writeList) + stateSourceObject.Load(6, &s.rtt) + stateSourceObject.Load(7, &s.minRTO) + stateSourceObject.Load(8, &s.maxRTO) + stateSourceObject.Load(9, &s.maxRetries) + stateSourceObject.Load(10, &s.gso) + stateSourceObject.Load(11, &s.state) + stateSourceObject.Load(12, &s.cc) + stateSourceObject.Load(13, &s.rc) + stateSourceObject.LoadValue(3, new(unixTime), func(y interface{}) { s.loadFirstRetransmittedSegXmitTime(y.(unixTime)) }) + stateSourceObject.AfterLoad(s.afterLoad) +} + +func (r *rtt) StateTypeName() string { + return "pkg/tcpip/transport/tcp.rtt" +} + +func (r *rtt) StateFields() []string { + return []string{ + "TCPRTTState", + } +} + +func (r *rtt) beforeSave() {} + +// +checklocksignore +func (r *rtt) StateSave(stateSinkObject state.Sink) { + r.beforeSave() + stateSinkObject.Save(0, &r.TCPRTTState) +} + +func (r *rtt) afterLoad() {} + +// +checklocksignore +func (r *rtt) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &r.TCPRTTState) +} + +func (u *unixTime) StateTypeName() string { + return "pkg/tcpip/transport/tcp.unixTime" +} + +func (u *unixTime) StateFields() []string { + return []string{ + "second", + "nano", + } +} + +func (u *unixTime) beforeSave() {} + +// +checklocksignore +func (u *unixTime) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + stateSinkObject.Save(0, &u.second) + stateSinkObject.Save(1, &u.nano) +} + +func (u *unixTime) afterLoad() {} + +// +checklocksignore +func (u *unixTime) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.second) + stateSourceObject.Load(1, &u.nano) +} + +func (l *endpointList) StateTypeName() string { + return "pkg/tcpip/transport/tcp.endpointList" +} + +func (l *endpointList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *endpointList) beforeSave() {} + +// +checklocksignore +func (l *endpointList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *endpointList) afterLoad() {} + +// +checklocksignore +func (l *endpointList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *endpointEntry) StateTypeName() string { + return "pkg/tcpip/transport/tcp.endpointEntry" +} + +func (e *endpointEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *endpointEntry) beforeSave() {} + +// +checklocksignore +func (e *endpointEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *endpointEntry) afterLoad() {} + +// +checklocksignore +func (e *endpointEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func (l *segmentList) StateTypeName() string { + return "pkg/tcpip/transport/tcp.segmentList" +} + +func (l *segmentList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *segmentList) beforeSave() {} + +// +checklocksignore +func (l *segmentList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *segmentList) afterLoad() {} + +// +checklocksignore +func (l *segmentList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *segmentEntry) StateTypeName() string { + return "pkg/tcpip/transport/tcp.segmentEntry" +} + +func (e *segmentEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *segmentEntry) beforeSave() {} + +// +checklocksignore +func (e *segmentEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *segmentEntry) afterLoad() {} + +// +checklocksignore +func (e *segmentEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*cubicState)(nil)) + state.Register((*SACKInfo)(nil)) + state.Register((*sndQueueInfo)(nil)) + state.Register((*rcvQueueInfo)(nil)) + state.Register((*accepted)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*keepalive)(nil)) + state.Register((*rackControl)(nil)) + state.Register((*receiver)(nil)) + state.Register((*renoState)(nil)) + state.Register((*renoRecovery)(nil)) + state.Register((*sackRecovery)(nil)) + state.Register((*SACKScoreboard)(nil)) + state.Register((*segment)(nil)) + state.Register((*segmentQueue)(nil)) + state.Register((*sender)(nil)) + state.Register((*rtt)(nil)) + state.Register((*unixTime)(nil)) + state.Register((*endpointList)(nil)) + state.Register((*endpointEntry)(nil)) + state.Register((*segmentList)(nil)) + state.Register((*segmentEntry)(nil)) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go deleted file mode 100644 index e7ede7662..000000000 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ /dev/null @@ -1,7929 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "fmt" - "io/ioutil" - "math" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - tcpiptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/test/testutil" - "gvisor.dev/gvisor/pkg/waiter" -) - -// endpointTester provides helper functions to test a tcpip.Endpoint. -type endpointTester struct { - ep tcpip.Endpoint -} - -// CheckReadError issues a read to the endpoint and checking for an error. -func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) { - t.Helper() - res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{}) - if got != want { - t.Fatalf("ep.Read = %s, want %s", got, want) - } - if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" { - t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff) - } -} - -// CheckRead issues a read to the endpoint and checking for a success, returning -// the data read. -func (e *endpointTester) CheckRead(t *testing.T) []byte { - t.Helper() - var buf bytes.Buffer - res, err := e.ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("ep.Read = _, %s; want _, nil", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - return buf.Bytes() -} - -// CheckReadFull reads from the endpoint for exactly count bytes. -func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte { - t.Helper() - var buf bytes.Buffer - w := tcpip.LimitedWriter{ - W: &buf, - N: int64(count), - } - for w.N != 0 { - _, err := e.ep.Read(&w, tcpip.ReadOptions{}) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for receive to be notified. - select { - case <-notifyRead: - case <-time.After(timeout): - t.Fatalf("Timed out waiting for data to arrive") - } - continue - } else if err != nil { - t.Fatalf("ep.Read = _, %s; want _, nil", err) - } - } - return buf.Bytes() -} - -const ( - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65535 - - // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an - // IPv4 endpoint when the MTU is set to defaultMTU in the test. - defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize -) - -func TestGiveUpConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventHUp) - defer wq.EventUnregister(&waitEntry) - - { - err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - // Close the connection, wait for completion. - ep.Close() - - // Wait for ep to become writable. - <-notifyCh - - // Call Connect again to retreive the handshake failure status - // and stats updates. - { - err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" { - t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got) - } - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } -} - -// Test for ICMP error handling without completing handshake. -func TestConnectICMPError(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventHUp) - defer wq.EventUnregister(&waitEntry) - - { - err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - syn := c.GetPacket() - checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn))) - - wep := ep.(interface { - StopWork() - ResumeWork() - LastErrorLocked() tcpip.Error - }) - - // Stop the protocol loop, ensure that the ICMP error is processed and - // the last ICMP error is read before the loop is resumed. This sanity - // tests the handshake completion logic on ICMP errors. - wep.StopWork() - - c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, defaultMTU) - - for { - if err := wep.LastErrorLocked(); err != nil { - if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { - t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d) - } - break - } - time.Sleep(time.Millisecond) - } - - wep.ResumeWork() - - <-notifyCh - - // The stack would have unregistered the endpoint because of the ICMP error. - // Expect a RST for any subsequent packets sent to the endpoint. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1, - AckNum: c.IRS + 1, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) -} - -func TestConnectIncrementActiveConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want) - } -} - -func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.FailedConnectionAttempts.Value() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want) - } -} - -func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - want := stats.TCP.FailedConnectionAttempts.Value() + 1 - - { - err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { - t.Errorf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { - t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want) - } -} - -func TestCloseWithoutConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - c.EP.Close() - - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func TestTCPSegmentsSentIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - // SYN and ACK - want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - if got := stats.TCP.SegmentsSent.Value(); got != want { - t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { - t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want) - } -} - -func TestTCPResetsSentIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - stats := c.Stack().Stats() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - want := stats.TCP.SegmentsSent.Value() + 1 - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - // If the AckNum is not the increment of the last sequence number, a RST - // segment is sent back in response. - AckNum: c.IRS + 2, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - c.GetPacket() - - metricPollFn := func() error { - if got := stats.TCP.ResetsSent.Value(); got != want { - return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want) - } - return nil - } - if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { - t.Error(err) - } -} - -// TestTCPResetsSentNoICMP confirms that we don't get an ICMP -// DstUnreachable packet when we try send a packet which is not part -// of an active session. -func TestTCPResetsSentNoICMP(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - stats := c.Stack().Stats() - - // Send a SYN request for a closed port. This should elicit an RST - // but NOT an ICMPv4 DstUnreachable packet. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive whatever comes back. - b := c.GetPacket() - ipHdr := header.IPv4(b) - if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want { - t.Errorf("unexpected protocol, got = %d, want = %d", got, want) - } - - // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. - sent := stats.ICMP.V4.PacketsSent - if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { - t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) - } -} - -// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates -// a RST if an ACK is received on the listening socket for which there is no -// active handshake in progress and we are not using SYN cookies. -func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Lower stackwide TIME_WAIT timeout so that the reservations - // are released instantly on Close. - tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err) - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - c.GetPacket() - - // Since an active close was done we need to wait for a little more than - // tcpLingerTimeout for the port reservations to be released and the - // socket to move to a CLOSED state. - time.Sleep(20 * time.Millisecond) - - // Now resend the same ACK, this ACK should generate a RST as there - // should be no endpoint in SYN-RCVD state and we are not using - // syn-cookies yet. The reason we send the same ACK is we need a valid - // cookie(IRS) generated by the netstack without which the ACK will be - // rejected. - c.SendPacket(nil, ackHeaders) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) -} - -func TestTCPResetsReceivedIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) - - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) - } -} - -func TestTCPResetsDoNotGenerateResets(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - want := stats.TCP.ResetsReceived.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - Flags: header.TCPFlagRst, - }) - - if got := stats.TCP.ResetsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want) - } - c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond) -} - -func TestActiveHandshake(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) -} - -func TestNonBlockingClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - // Close the endpoint and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %s", diff) - } -} - -func TestConnectResetAfterClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPLinger to 3 seconds so that sockets are marked closed - // after 3 second in FIN_WAIT2 state. - tcpLingerTimeout := 3 * time.Second - opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - // Close the endpoint, make sure we get a FIN segment, then acknowledge - // to complete closure of sender, but don't send our own FIN. - ep.Close() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Wait for the ep to give up waiting for a FIN. - time.Sleep(tcpLingerTimeout + 1*time.Second) - - // Now send an ACK and it should trigger a RST as the endpoint should - // not exist anymore. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - for { - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { - // This is a retransmit of the FIN, ignore it. - continue - } - - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - // RST is always generated with sndNxt which if the FIN - // has been sent will be 1 higher than the sequence number - // of the FIN itself. - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - break - } -} - -// TestCurrentConnectedIncrement tests increment of the current -// established and connected counters. -func TestCurrentConnectedIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed - // after 1 second in TIME_WAIT state. - tcpTimeWaitTimeout := 1 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got) - } - gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value() - if gotConnected != 1 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected) - } - - ep.Close() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected) - } - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Wait for a little more than the TIME-WAIT duration for the socket to - // transition to CLOSED state. - time.Sleep(1200 * time.Millisecond) - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -// TestClosingWithEnqueuedSegments tests handling of still enqueued segments -// when the endpoint transitions to StateClose. The in-flight segments would be -// re-enqueued to a any listening endpoint. -func TestClosingWithEnqueuedSegments(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - ep := c.EP - c.EP = nil - - if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want { - t.Errorf("unexpected endpoint state: want %d, got %d", want, got) - } - - // Send a FIN for ESTABLISHED --> CLOSED-WAIT - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Get the ACK for the FIN we sent. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Give the stack a few ms to transition the endpoint out of ESTABLISHED - // state. - time.Sleep(10 * time.Millisecond) - - if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want { - t.Errorf("unexpected endpoint state: want %d, got %d", want, got) - } - - // Close the application endpoint for CLOSE_WAIT --> LAST_ACK - ep.Close() - - // Get the FIN - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // Pause the endpoint`s protocolMainLoop. - ep.(interface{ StopWork() }).StopWork() - - // Enqueue last ACK followed by an ACK matching the endpoint - // - // Send Last ACK for LAST_ACK --> CLOSED - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(1), - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Send a packet with ACK set, this would generate RST when - // not using SYN cookies as in this test. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss.Add(2), - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Unpause endpoint`s protocolMainLoop. - ep.(interface{ ResumeWork() }).ResumeWork() - - // Wait for the protocolMainLoop to resume and update state. - time.Sleep(10 * time.Millisecond) - - // Expect the endpoint to be closed. - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got) - } - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - - // Check if the endpoint was moved to CLOSED and netstack a reset in - // response to the ACK packet that we sent after last-ACK. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst), - ), - ) -} - -func TestSimpleReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - - data := []byte{1, 2, 3} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Receive data. - v := ept.CheckRead(t) - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Check that ACK is received. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when -// creating a new active TCP socket. It should be present in the sent TCP -// SYN segment. -func TestUserSuppliedMSSOnConnect(t *testing.T) { - const mtu = 5000 - - ips := []struct { - name string - createEP func(*context.Context) - connectAddr tcpip.Address - checker func(*testing.T, *context.Context, uint16, int) - maxMSS uint16 - }{ - { - name: "IPv4", - createEP: func(c *context.Context) { - c.Create(-1) - }, - connectAddr: context.TestAddr, - checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) - }, - maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, - }, - { - name: "IPv6", - createEP: func(c *context.Context) { - c.CreateV6Endpoint(true) - }, - connectAddr: context.TestV6Addr, - checker: func(t *testing.T, c *context.Context, mss uint16, ws int) { - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws}))) - }, - maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, - }, - } - - for _, ip := range ips { - t.Run(ip.name, func(t *testing.T) { - tests := []struct { - name string - setMSS uint16 - expMSS uint16 - }{ - { - name: "EqualToMaxMSS", - setMSS: ip.maxMSS, - expMSS: ip.maxMSS, - }, - { - name: "LessThanMaxMSS", - setMSS: ip.maxMSS - 1, - expMSS: ip.maxMSS - 1, - }, - { - name: "GreaterThanMaxMSS", - setMSS: ip.maxMSS + 1, - expMSS: ip.maxMSS, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() - - ip.createEP(c) - - // Set the MSS socket option. - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) - } - - // Get expected window size. - rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize() - ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) - - connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} - { - err := c.EP.Connect(connectAddr) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d) - } - } - - // Receive SYN packet with our user supplied MSS. - ip.checker(t, c, test.expMSS, ws) - }) - } - }) - } -} - -// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used -// when completing the handshake for a new TCP connection from a TCP -// listening socket. It should be present in the sent TCP SYN-ACK segment. -func TestUserSuppliedMSSOnListenAccept(t *testing.T) { - const mtu = 5000 - - ips := []struct { - name string - createEP func(*context.Context) - sendPkt func(*context.Context, *context.Headers) - checker func(*testing.T, *context.Context, uint16, uint16) - maxMSS uint16 - }{ - { - name: "IPv4", - createEP: func(c *context.Context) { - c.Create(-1) - }, - sendPkt: func(c *context.Context, h *context.Headers) { - c.SendPacket(nil, h) - }, - checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) - }, - maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, - }, - { - name: "IPv6", - createEP: func(c *context.Context) { - c.CreateV6Endpoint(false) - }, - sendPkt: func(c *context.Context, h *context.Headers) { - c.SendV6Packet(nil, h) - }, - checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) { - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1}))) - }, - maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize, - }, - } - - for _, ip := range ips { - t.Run(ip.name, func(t *testing.T) { - tests := []struct { - name string - setMSS uint16 - expMSS uint16 - }{ - { - name: "EqualToMaxMSS", - setMSS: ip.maxMSS, - expMSS: ip.maxMSS, - }, - { - name: "LessThanMaxMSS", - setMSS: ip.maxMSS - 1, - expMSS: ip.maxMSS - 1, - }, - { - name: "GreaterThanMaxMSS", - setMSS: ip.maxMSS + 1, - expMSS: ip.maxMSS, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, mtu) - defer c.Cleanup() - - ip.createEP(c) - - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { - t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) - } - - bindAddr := tcpip.FullAddress{Port: context.StackPort} - if err := c.EP.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s:", bindAddr, err) - } - - backlog := 5 - // Keep the number of client requests twice to the backlog - // such that half of the connections do not use syncookies - // and the other half does. - clientConnects := backlog * 2 - - if err := c.EP.Listen(backlog); err != nil { - t.Fatalf("Listen(%d): %s:", backlog, err) - } - - for i := 0; i < clientConnects; i++ { - // Send a SYN requests. - iss := seqnum.Value(i) - srcPort := context.TestPort + uint16(i) - ip.sendPkt(c, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - }) - - // Receive the SYN-ACK reply. - ip.checker(t, c, srcPort, test.expMSS) - } - }) - } - }) - } -} -func TestSendRstOnListenerRxSynAckV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) -} - -func TestSendRstOnListenerRxSynAckV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) -} - -// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete, -// peers can send data and expect a response within a reasonable ammount of time -// without calling Accept on the listening endpoint first. -// -// This test uses IPv4. -func TestTCPAckBeforeAcceptV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - // Send data before accepting the connection. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) -} - -// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete, -// peers can send data and expect a response within a reasonable ammount of time -// without calling Accept on the listening endpoint first. -// -// This test uses IPv6. -func TestTCPAckBeforeAcceptV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */) - - // Send data before accepting the connection. - c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) -} - -func TestSendRstOnListenerRxAckV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) -} - -func TestSendRstOnListenerRxAckV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true /* v6Only */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - c.SendV6Packet(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagFin | header.TCPFlagAck, - SeqNum: 100, - AckNum: 200, - }) - - checker.IPv6(t, c.GetV6Packet(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - checker.TCPSeqNum(200))) -} - -// TestListenShutdown tests for the listening endpoint replying with RST -// on read shutdown. -func TestListenShutdown(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(1 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatal("Shutdown failed:", err) - } - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: 100, - AckNum: 200, - }) - - // Expect the listening endpoint to reset the connection. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - )) -} - -var _ waiter.EntryCallback = (callback)(nil) - -type callback func(*waiter.Entry, waiter.EventMask) - -func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) { - cb(entry, mask) -} - -func TestListenerReadinessOnEvent(t *testing.T) { - s := stack.New(stack.Options{ - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - { - ep := loopback.New() - if testing.Verbose() { - ep = sniffer.New(ep) - } - const id = 1 - if err := s.CreateNIC(id, ep); err != nil { - t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) - } - if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil { - t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: header.IPv4EmptySubnet, NIC: id}, - }) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil { - t.Fatalf("Bind(%s): %s", context.StackAddr, err) - } - const backlog = 1 - if err := ep.Listen(backlog); err != nil { - t.Fatalf("Listen(%d): %s", backlog, err) - } - - address, err := ep.GetLocalAddress() - if err != nil { - t.Fatalf("GetLocalAddress(): %s", err) - } - - conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) - } - defer conn.Close() - - events := make(chan waiter.EventMask) - // Scope `entry` to allow a binding of the same name below. - { - entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) { - events <- ep.Readiness(mask) - })} - wq.EventRegister(&entry, waiter.EventIn) - defer wq.EventUnregister(&entry) - } - - entry, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&entry, waiter.EventOut) - defer wq.EventUnregister(&entry) - - switch err := conn.Connect(address).(type) { - case *tcpip.ErrConnectStarted: - default: - t.Fatalf("Connect(%#v): %v", address, err) - } - - // Read at least one event. - got := <-events - for { - select { - case event := <-events: - got |= event - continue - case <-ch: - if want := waiter.ReadableEvents; got != want { - t.Errorf("observed events = %b, want %b", got, want) - } - } - break - } -} - -// TestListenCloseWhileConnect tests for the listening endpoint to -// drain the accept-queue when closed. This should reset all of the -// pending connections that are waiting to be accepted. -func TestListenCloseWhileConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1 /* epRcvBuf */) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(1 /* backlog */); err != nil { - t.Fatal("Listen failed:", err) - } - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&waitEntry) - - executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - // Wait for the new endpoint created because of handshake to be delivered - // to the listening endpoint's accept queue. - <-notifyCh - - // Close the listening endpoint. - c.EP.Close() - - // Expect the listening endpoint to reset the connection. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - )) -} - -func TestTOSV4(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - - const tos = 0xC0 - if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { - t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err) - } - - v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption) - if err != nil { - t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err) - } - - if v != tos { - t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos) - } - - testV4Connect(t, c, checker.TOS(tos, 0)) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - checker.TOS(tos, 0), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } -} - -func TestTrafficClassV6(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(false) - - const tos = 0xC0 - if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil { - t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err) - } - - v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) - if err != nil { - t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err) - } - - if v != tos { - t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos) - } - - // Test the connection request. - testV6Connect(t, c, checker.TOS(tos, 0)) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetV6Packet() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv6(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - checker.TOS(tos, 0), - ) - - if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } -} - -func TestConnectBindToDevice(t *testing.T) { - for _, test := range []struct { - name string - device tcpip.NICID - want tcp.EndpointState - }{ - {"RightDevice", 1, tcp.StateEstablished}, - {"WrongDevice", 2, tcp.StateSynSent}, - {"AnyDevice", 0, tcp.StateEstablished}, - } { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err) - } - // Start connection attempt. - waitEntry, _ := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.WritableEvents) - defer c.WQ.EventUnregister(&waitEntry) - - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) - } - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) - - c.GetPacket() - if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { - t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) - } - }) - } -} - -func TestSynSent(t *testing.T) { - for _, test := range []struct { - name string - reset bool - }{ - {"RstOnSynSent", true}, - {"CloseOnSynSent", false}, - } { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create an endpoint, don't handshake because we want to interfere with the - // handshake process. - c.Create(-1) - - // Start connection attempt. - waitEntry, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventHUp) - defer c.WQ.EventUnregister(&waitEntry) - - addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} - err := c.EP.Connect(addr) - if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" { - t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - if test.reset { - // Send a packet with a proper ACK and a RST flag to cause the socket - // to error and close out. - iss := seqnum.Value(context.TestInitialSequenceNumber) - rcvWnd := seqnum.Size(30000) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagRst | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: nil, - }) - } else { - c.EP.Close() - } - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatal("timed out waiting for packet to arrive") - } - - ept := endpointTester{c.EP} - if test.reset { - ept.CheckReadError(t, &tcpip.ErrConnectionRefused{}) - } else { - ept.CheckReadError(t, &tcpip.ErrAborted{}) - } - - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } - - // Due to the RST the endpoint should be in an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Fatalf("got State() = %s, want %s", got, want) - } - }) - } -} - -func TestOutOfOrderReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send second half of data first, with seqnum 3 ahead of expected. - data := []byte{1, 2, 3, 4, 5, 6} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(3), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that we get an ACK specifying which seqnum is expected. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Wait 200ms and check that no data has been received. - time.Sleep(200 * time.Millisecond) - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send the first 3 bytes now. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive data. - read := ept.CheckReadFull(t, 6, ch, 5*time.Second) - - // Check that we received the data in proper order. - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that the whole data is acknowledged. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestOutOfOrderFlood(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - rcvBufSz := math.MaxUint16 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send 100 packets before the actual one that is expected. - data := []byte{1, 2, 3, 4, 5, 6} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i := 0; i < 100; i++ { - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(6), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Send packet with seqnum as initial + 3. It must be discarded because the - // out-of-order buffer was filled by the previous packets. - c.SendPacket(data[3:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(3), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Now send the expected packet with initial sequence number. - c.SendPacket(data[:3], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that only packet with initial sequence number is acknowledged. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+3), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestRstOnCloseWithUnreadData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - data := []byte{1, 2, 3} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Now that we know we have unread data, let's just close the connection - // and verify that netstack sends an RST rather than a FIN. - c.EP.Close() - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.TCPSeqNum(uint32(c.IRS)+1), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // This final ACK should be ignored because an ACK on a reset doesn't mean - // anything. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(len(data))), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} - -func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - data := []byte{1, 2, 3} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(3 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that ACK is received, this happens regardless of the read. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Cause a FIN to be generated. - c.EP.Shutdown(tcpip.ShutdownWrite) - - // Make sure we get the FIN but DON't ACK IT. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - checker.TCPSeqNum(uint32(c.IRS)+1), - )) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // Cause a RST to be generated by closing the read end now since we have - // unread data. - c.EP.Shutdown(tcpip.ShutdownRead) - - // Make sure we get the RST - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // RST is always generated with sndNxt which if the FIN - // has been sent will be 1 higher than the sequence - // number of the FIN itself. - checker.TCPSeqNum(uint32(c.IRS)+2), - )) - // The RST puts the endpoint into an error state. - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // The ACK to the FIN should now be rejected since the connection has been - // closed by a RST. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(len(data))), - AckNum: c.IRS.Add(seqnum.Size(2)), - RcvWnd: 30000, - }) -} - -func TestShutdownRead(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) - var want uint64 = 1 - if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { - t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) - } -} - -func TestFullWindowReceive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const rcvBufSz = 10 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies - // the provided buffer value by tcp.SegOverheadFactor to calculate the actual - // receive buffer size. - data := make([]byte, tcp.SegOverheadFactor*rcvBufSz) - for i := range data { - data[i] = byte(i % 255) - } - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that data is acknowledged, and window goes to zero. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(0), - ), - ) - - // Receive data and check it. - v := ept.CheckRead(t) - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - var want uint64 = 1 - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { - t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want) - } - - // Check that we get an ACK for the newly non-zero window. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(len(data))), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(10), - ), - ) -} - -// Test the stack receive window advertisement on receiving segments smaller than -// segment overhead. It tests for the right edge of the window to not grow when -// the endpoint is not being read from. -func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - opt := tcpip.TCPReceiveBufferSizeRangeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize, - Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)), - } - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - - c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Bump up the receive buffer size such that, when the receive window grows, - // the scaled window exceeds maxUint16. - c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max), true) - - // Keep the payload size < segment overhead and such that it is a multiple - // of the window scaled value. This enables the test to perform equality - // checks on the incoming receive window. - payloadSize := 1 << c.RcvdWindowScale - if payloadSize >= tcp.SegSize { - t.Fatalf("payload size of %d is not less than the segment overhead of %d", payloadSize, tcp.SegSize) - } - payload := generateRandomPayload(t, payloadSize) - payloadLen := seqnum.Size(len(payload)) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - - // Send payload to the endpoint and return the advertised receive window - // from the endpoint. - getIncomingRcvWnd := func() uint32 { - c.SendPacket(payload, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss, - AckNum: c.IRS.Add(1), - Flags: header.TCPFlagAck, - RcvWnd: 30000, - }) - iss = iss.Add(payloadLen) - - pkt := c.GetPacket() - return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale - } - - // Read the advertised receive window with the ACK for payload. - rcvWnd := getIncomingRcvWnd() - - // Check if the subsequent ACK to our send has not grown the right edge of - // the window. - if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want { - t.Fatalf("got incomingRcvwnd %d want %d", got, want) - } - - // Read the data so that the subsequent ACK from the endpoint - // grows the right edge of the window. - var buf bytes.Buffer - if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil { - t.Fatalf("c.EP.Read: %s", err) - } - - // Check if we have received max uint16 as our advertised - // scaled window now after a read above. - maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale) - if got, want := getIncomingRcvWnd(), maxRcv; got != want { - t.Fatalf("got incomingRcvwnd %d want %d", got, want) - } - - // Check if the subsequent ACK to our send has not grown the right edge of - // the window. - if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want { - t.Fatalf("got incomingRcvwnd %d want %d", got, want) - } -} - -func TestNoWindowShrinking(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Start off with a certain receive buffer then cut it in half and verify that - // the right edge of the window does not shrink. - // NOTE: Netstack doubles the value specified here. - rcvBufSize := 65536 - // Enable window scaling with a scale of zero from our end. - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, rcvBufSize, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send a 1 byte payload so that we can record the current receive window. - // Send a payload of half the size of rcvBufSize. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - payload := []byte{1} - c.SendPacket(payload, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Read the 1 byte payload we just sent. - if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) { - t.Fatalf("got data: %v, want: %v", got, want) - } - - // Verify that the ACK does not shrink the window. - pkt := c.GetPacket() - iss = iss.Add(1) - checker.IPv4(t, pkt, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Stash the initial window. - initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale - initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd)) - // Now shrink the receive buffer to half its original size. - c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize/2), true) - - data := generateRandomPayload(t, rcvBufSize) - // Send a payload of half the size of rcvBufSize. - c.SendPacket(data[:rcvBufSize/2], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - iss = iss.Add(seqnum.Size(rcvBufSize / 2)) - - // Verify that the ACK does not shrink the window. - pkt = c.GetPacket() - checker.IPv4(t, pkt, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale - newLastAcceptableSeq := iss.Add(seqnum.Size(newWnd)) - if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) { - t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq) - } - - // Send another payload of half the size of rcvBufSize. This should fill up the - // socket receive buffer and we should see a zero window. - c.SendPacket(data[rcvBufSize/2:], &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - iss = iss.Add(seqnum.Size(rcvBufSize / 2)) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(0), - ), - ) - - // Receive data and check it. - read := ept.CheckReadFull(t, len(data), ch, 5*time.Second) - if !bytes.Equal(data, read) { - t.Fatalf("got data = %v, want = %v", read, data) - } - - // Check that we get an ACK for the newly non-zero window, which is the new - // receive buffer size we set after the connection was established. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale), - ), - ) -} - -func TestSimpleSend(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, - }) -} - -func TestZeroWindowSend(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 0 /* rcvWnd */, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check if we got a zero-window probe. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Open up the window. Data should be received now. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Check that data is received. - b = c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Fatalf("got data = %v, want = %v", p, data) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), - RcvWnd: 30000, - }) -} - -func TestScaledWindowConnect(t *testing.T) { - // This test ensures that window scaling is used when the peer - // does advertise it and connection is established with Connect(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size greater than the maximum non-scaled window. - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, 65535*3, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received, and that advertised window is 0x5fff, - // that is, that it is scaled. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) -} - -func TestNonScaledWindowConnect(t *testing.T) { - // This test ensures that window scaling is not used when the peer - // doesn't advertise it and connection is established with Connect(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the window size greater than the maximum non-scaled window. - c.CreateConnected(context.TestInitialSequenceNumber, 30000, 65535*3) - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) -} - -func TestScaledWindowAccept(t *testing.T) { - // This test ensures that window scaling is used when the peer - // does advertise it and connection is established with Accept(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - // Set the window size greater than the maximum non-scaled window. - ep.SocketOptions().SetReceiveBufferSize(65535*3, true) - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Do 3-way handshake. - // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2 - c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received, and that advertised window is 0x5fff, - // that is, that it is scaled. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) -} - -func TestNonScaledWindowAccept(t *testing.T) { - // This test ensures that window scaling is not used when the peer - // doesn't advertise it and connection is established with Accept(). - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - // Set the window size greater than the maximum non-scaled window. - ep.SocketOptions().SetReceiveBufferSize(65535*3, true) - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN - // should not carry the window scaling option. - c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received, and that advertised window is 0xffff, - // that is, that it's not scaled. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) -} - -func TestZeroScaledWindowReceive(t *testing.T) { - // This test ensures that the endpoint sends a non-zero window size - // advertisement when the scaled window transitions from 0 to non-zero, - // but the actual window (not scaled) hasn't gotten to zero. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set the buffer size such that a window scale of 5 will be used. - const bufSz = 65535 * 10 - const ws = uint32(5) - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, bufSz, []byte{ - header.TCPOptionWS, 3, 0, header.TCPOptionNOP, - }) - - // Write chunks of 50000 bytes. - remain := 0 - sent := 0 - data := make([]byte, 50000) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Keep writing till the window drops below len(data). - for { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(sent)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Don't reduce window to zero here. - if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) { - remain = wnd << ws - break - } - } - - // Make the window non-zero, but the scaled window zero. - for remain >= 16 { - data = data[:remain-15] - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(sent)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Since the receive buffer is split between window advertisement and - // application data buffer the window does not always reflect the space - // available and actual space available can be a bit more than what is - // advertised in the window. - wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) - if wnd == 0 { - break - } - remain = wnd << ws - } - - // Read at least 2MSS of data. An ack should be sent in response to that. - // Since buffer space is now split in half between window and application - // data we need to read more than 1 MSS(65536) of data for a non-zero window - // update to be sent. For 1MSS worth of window to be available we need to - // read at least 128KB. Since our segments above were 50KB each it means - // we need to read at 3 packets. - w := tcpip.LimitedWriter{ - W: ioutil.Discard, - N: defaultMTU * 2, - } - for w.N != 0 { - res, err := c.EP.Read(&w, tcpip.ReadOptions{}) - t.Logf("err=%v res=%#v", err, res) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestSegmentMerging(t *testing.T) { - tests := []struct { - name string - stop func(tcpip.Endpoint) - resume func(tcpip.Endpoint) - }{ - { - "stop work", - func(ep tcpip.Endpoint) { - ep.(interface{ StopWork() }).StopWork() - }, - func(ep tcpip.Endpoint) { - ep.(interface{ ResumeWork() }).ResumeWork() - }, - }, - { - "cork", - func(ep tcpip.Endpoint) { - ep.SocketOptions().SetCorkOption(true) - }, - func(ep tcpip.Endpoint) { - ep.SocketOptions().SetCorkOption(false) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Send tcp.InitialCwnd number of segments to fill up - // InitialWindow but don't ACK. That should prevent - // anymore packets from going out. - var r bytes.Reader - for i := 0; i < tcp.InitialCwnd; i++ { - r.Reset([]byte{0}) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - // Now send the segments that should get merged as the congestion - // window is full and we won't be able to send any more packets. - var allData []byte - for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - // Check that we get tcp.InitialCwnd packets. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i := 0; i < tcp.InitialCwnd; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(header.TCPMinimumSize+1), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload. - RcvWnd: 30000, - }) - - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(allData)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+11), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) { - t.Fatalf("got data = %v, want = %v", got, allData) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), - RcvWnd: 30000, - }) - }) - } -} - -func TestDelay(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - c.EP.SocketOptions().SetDelayOption(true) - - var allData []byte - for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { - allData = append(allData, data...) - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - seq := c.IRS.Add(1) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for _, want := range [][]byte{allData[:1], allData[1:]} { - // Check that data is received. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(want)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) { - t.Fatalf("got data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(want))) - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seq, - RcvWnd: 30000, - }) - } -} - -func TestUndelay(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - c.EP.SocketOptions().SetDelayOption(true) - - allData := [][]byte{{0}, {1, 2, 3}} - for i, data := range allData { - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - seq := c.IRS.Add(1) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Check that data is received. - first := c.GetPacket() - checker.IPv4(t, first, - checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) { - t.Fatalf("got first packet's data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(allData[0]))) - - // Check that we don't get the second packet yet. - c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - - c.EP.SocketOptions().SetDelayOption(false) - - // Check that data is received. - second := c.GetPacket() - checker.IPv4(t, second, - checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) { - t.Fatalf("got second packet's data = %v, want = %v", got, want) - } - - seq = seq.Add(seqnum.Size(len(allData[1]))) - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seq, - RcvWnd: 30000, - }) -} - -func TestMSSNotDelayed(t *testing.T) { - tests := []struct { - name string - fn func(tcpip.Endpoint) - }{ - {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const maxPayload = 100 - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - test.fn(c.EP) - - allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} - for i, data := range allData { - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write #%d failed: %s", i+1, err) - } - } - - seq := c.IRS.Add(1) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i, data := range allData { - // Check that data is received. - packet := c.GetPacket() - checker.IPv4(t, packet, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(seq)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) { - t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want) - } - - seq = seq.Add(seqnum.Size(len(data))) - } - - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seq, - RcvWnd: 30000, - }) - }) - } -} - -func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { - payloadMultiplier := 10 - dataLen := payloadMultiplier * maxPayload - data := make([]byte, dataLen) - for i := range data { - data[i] = byte(i) - } - - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received in chunks. - bytesReceived := 0 - numPackets := 0 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for bytesReceived != dataLen { - b := c.GetPacket() - numPackets++ - tcpHdr := header.TCP(header.IPv4(b).Payload()) - payloadLen := len(tcpHdr.Payload()) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - pdata := data[bytesReceived : bytesReceived+payloadLen] - if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { - t.Fatalf("got data = %v, want = %v", p, pdata) - } - bytesReceived += payloadLen - var options []byte - if c.TimeStampEnabled { - // If timestamp option is enabled, echo back the timestamp and increment - // the TSEcr value included in the packet and send that back as the TSVal. - parsedOpts := tcpHdr.ParsedOptions() - tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) - options = tsOpt[:] - } - // Acknowledge the data. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), - RcvWnd: 30000, - TCPOpts: options, - }) - } - if numPackets == 1 { - t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet") - } -} - -func TestSendGreaterThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSetTTL(t *testing.T) { - for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { - t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { - c := context.New(t, 65535) - defer c.Cleanup() - - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { - t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) - } - - { - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - // Receive SYN packet. - b := c.GetPacket() - - checker.IPv4(t, b, checker.TTL(wantTTL)) - }) - } -} - -func TestActiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - c := context.New(t, 65535) - defer c.Cleanup() - - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestPassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 536 - const mtu = 2000 - c := context.New(t, mtu) - defer c.Cleanup() - - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestForwarderSendMSSLessThanMTU(t *testing.T) { - const maxPayload = 100 - const mtu = 1200 - c := context.New(t, mtu) - defer c.Cleanup() - - s := c.Stack() - ch := make(chan tcpip.Error, 1) - f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err tcpip.Error - c.EP, err = r.CreateEndpoint(&c.WQ) - ch <- err - }) - s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) - - // Do 3-way handshake. - c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) - - // Wait for connection to be available. - select { - case err := <-ch: - if err != nil { - t.Fatalf("Error creating endpoint: %s", err) - } - case <-time.After(2 * time.Second): - t.Fatalf("Timed out waiting for connection") - } - - // Check that data gets properly segmented. - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSynOptionsOnActiveConnect(t *testing.T) { - const mtu = 1400 - c := context.New(t, mtu) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Set the buffer size to a deterministic size so that we can check the - // window scaling option. - const rcvBufferSize = 0x20000 - const wndScale = 3 - c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) - - // Start connection attempt. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.WritableEvents) - defer c.WQ.EventUnregister(&we) - - { - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - // Receive SYN packet. - b := c.GetPacket() - mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) - - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - // Wait for retransmit. - time.Sleep(1 * time.Second) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.SrcPort(tcpHdr.SourcePort()), - checker.TCPSeqNum(tcpHdr.SequenceNumber()), - checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), - ), - ) - - // Send SYN-ACK. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK packet. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-ch: - if err := c.EP.LastError(); err != nil { - t.Fatalf("Connect failed: %s", err) - } - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for connection") - } -} - -func TestCloseListener(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Close the listener and measure how long it takes. - t0 := time.Now() - ep.Close() - if diff := time.Now().Sub(t0); diff > 3*time.Second { - t.Fatalf("Took too long to close: %s", diff) - } -} - -func TestReceiveOnResetConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - // Send RST segment. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Try to read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - -loop: - for { - switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { - case *tcpip.ErrWouldBlock: - select { - case <-ch: - // Expect the state to be StateError and subsequent Reads to fail with HardError. - _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { - t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) - } - break loop - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for reset to arrive") - } - case *tcpip.ErrConnectionReset: - break loop - default: - t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) - } - } - - if tcp.EndpointState(c.EP.State()) != tcp.StateError { - t.Fatalf("got EP state is not StateError") - } - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { - t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) - } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func TestSendOnResetConnection(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Send RST segment. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Wait for the RST to be received. - time.Sleep(1 * time.Second) - - // Try to write. - var r bytes.Reader - r.Reset(make([]byte, 10)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { - t.Fatalf("c.EP.Write(...) mismatch (-want +got):\n%s", d) - } -} - -// TestMaxRetransmitsTimeout tests if the connection is timed out after -// a segment has been retransmitted MaxRetries times. -func TestMaxRetransmitsTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const numRetries = 2 - opt := tcpip.TCPMaxRetriesOption(numRetries) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventHUp) - defer c.WQ.EventUnregister(&waitEntry) - - var r bytes.Reader - r.Reset(make([]byte, 1)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Expect first transmit and MaxRetries retransmits. - for i := 0; i < numRetries+1; i++ { - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), - ), - ) - } - // Wait for the connection to timeout after MaxRetries retransmits. - initRTO := 1 * time.Second - select { - case <-notifyCh: - case <-time.After((2 << numRetries) * initRTO): - t.Fatalf("connection still alive after maximum retransmits.\n") - } - - // Send an ACK and expect a RST as the connection would have been closed. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -// TestMaxRTO tests if the retransmit interval caps to MaxRTO. -func TestMaxRTO(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - rto := 1 * time.Second - opt := tcpip.TCPMaxRTOOption(rto) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - - var r bytes.Reader - r.Reset(make([]byte, 1)) - _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if err != nil { - t.Fatalf("Write failed: %s", err) - } - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - const numRetransmits = 2 - for i := 0; i < numRetransmits; i++ { - start := time.Now() - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() { - t.Errorf("Retransmit interval not capped to MaxRTO.\n") - } - } -} - -// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is -// unique on retransmits. -func TestRetransmitIPv4IDUniqueness(t *testing.T) { - for _, tc := range []struct { - name string - size int - }{ - {"1Byte", 1}, - {"512Bytes", 512}, - } { - t.Run(tc.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - - // Disabling PMTU discovery causes all packets sent from this socket to - // have DF=0. This needs to be done because the IPv4 ID uniqueness - // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 - // Section 4, and datagrams with DF=0 are non-atomic. - if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil { - t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) - } - - var r bytes.Reader - r.Reset(make([]byte, tc.size)) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.FragmentFlags(0), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}} - // Expect two retransmitted packets, and that all packets received have - // unique IPv4 ID values. - for i := 0; i <= 2; i++ { - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.FragmentFlags(0), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - id := header.IPv4(pkt).ID() - if _, exists := idSet[id]; exists { - t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) - } - idSet[id] = struct{}{} - } - }) - } -} - -func TestFinImmediately(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinRetransmit(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Shutdown immediately, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Don't acknowledge yet. We should get a retransmit of the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithNoPendingData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Write something out, and have it acknowledged. - view := make([]byte, 10) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Shutdown, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Ack and send FIN as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Check that the stack acks the FIN. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPendingDataCwndFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Write enough segments to fill the congestion window before ACK'ing - // any of them. - view := make([]byte, 10) - var r bytes.Reader - for i := tcp.InitialCwnd; i > 0; i-- { - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - } - - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i := tcp.InitialCwnd; i > 0; i-- { - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - } - - // Shutdown the connection, check that the FIN segment isn't sent - // because the congestion window doesn't allow it. Wait until a - // retransmit is received. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Send the ACK that will allow the FIN to be sent as well. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send a FIN that acknowledges everything. Get an ACK back. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPendingData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Write something out, and acknowledge it to get cwnd to 2. - view := make([]byte, 10) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Write new data, but don't acknowledge it. - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send a FIN that acknowledges everything. Get an ACK back. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestFinWithPartialAck(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Write something out, and acknowledge it to get cwnd to 2. Also send - // FIN from the test side. - view := make([]byte, 10) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Check that we get an ACK for the fin. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Write new data, but don't acknowledge it. - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - next += uint32(len(view)) - - // Shutdown the connection, check that we do get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - next++ - - // Send an ACK for the data, but not for the FIN yet. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(1), - AckNum: seqnum.Value(next - 1), - RcvWnd: 30000, - }) - - // Check that we don't get a retransmit of the FIN. - c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond) - - // Ack the FIN. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss.Add(1), - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) -} - -func TestUpdateListenBacklog(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create listener. - var wq waiter.Queue - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - if err := ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Update the backlog with another Listen() on the same endpoint. - if err := ep.Listen(20); err != nil { - t.Fatalf("Listen failed to update backlog: %s", err) - } - - ep.Close() -} - -func scaledSendWindow(t *testing.T, scale uint8) { - // This test ensures that the endpoint is using the right scaling by - // sending a buffer that is larger than the window size, and ensuring - // that the endpoint doesn't send more than allowed. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 0, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - header.TCPOptionWS, 3, scale, header.TCPOptionNOP, - }) - - // Open up the window with a scaled value. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 1, - }) - - // Send some data. Check that it's capped by the window size. - view := make([]byte, 65535) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that only data that fits in the scaled window is sent. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen((1<<scale)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Reset the connection to free resources. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagRst, - SeqNum: iss, - }) -} - -func TestScaledSendWindow(t *testing.T) { - for scale := uint8(0); scale <= 14; scale++ { - scaledSendWindow(t, scale) - } -} - -func TestReceivedValidSegmentCountIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.ValidSegmentsReceived.Value() + 1 - - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { - t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want) - } - // Ensure there were no errors during handshake. If these stats have - // incremented, then the connection should not have been established. - if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { - t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0) - } -} - -func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.InvalidSegmentsReceived.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - vv := c.BuildSegment(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] - tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 - - c.SendSegment(vv) - - if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { - t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) - } -} - -func TestReceivedIncorrectChecksumIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - stats := c.Stack().Stats() - want := stats.TCP.ChecksumErrors.Value() + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] - // Overwrite a byte in the payload which should cause checksum - // verification to fail. - tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 - - c.SendSegment(vv) - - if got := stats.TCP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) - } -} - -func TestReceivedSegmentQueuing(t *testing.T) { - // This test sends 200 segments containing a few bytes each to an - // endpoint and checks that they're all received and acknowledged by - // the endpoint, that is, that none of the segments are dropped by - // internal queues. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - // Send 200 segments. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - data := []byte{1, 2, 3} - for i := 0; i < 200; i++ { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(i * len(data))), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - } - - // Receive ACKs for all segments. - last := iss.Add(seqnum.Size(200 * len(data))) - for { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - tcpHdr := header.TCP(header.IPv4(b).Payload()) - ack := seqnum.Value(tcpHdr.AckNumber()) - if ack == last { - break - } - - if last.LessThan(ack) { - t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last) - } - } -} - -func TestReadAfterClosedState(t *testing.T) { - // This test ensures that calling Read() or Peek() after the endpoint - // has transitioned to closedState still works if there is pending - // data. To transition to stateClosed without calling Close(), we must - // shutdown the send path and the peer must send its own FIN. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed - // after 1 second in TIME_WAIT state. - tcpTimeWaitTimeout := 1 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Shutdown immediately for write, check that we get a FIN. - if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), - ), - ) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // Send some data and acknowledge the FIN. - data := []byte{1, 2, 3} - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss, - AckNum: c.IRS.Add(2), - RcvWnd: 30000, - }) - - // Check that ACK is received. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+2), - checker.TCPAckNum(uint32(iss)+uint32(len(data))+1), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Give the stack the chance to transition to closed state from - // TIME_WAIT. - time.Sleep(tcpTimeWaitTimeout * 2) - - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Check that peek works. - var peekBuf bytes.Buffer - res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true}) - if err != nil { - t.Fatalf("Peek failed: %s", err) - } - - if got, want := res.Count, len(data); got != want { - t.Fatalf("res.Count = %d, want %d", got, want) - } - if !bytes.Equal(data, peekBuf.Bytes()) { - t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data) - } - - // Receive data. - v := ept.CheckRead(t) - if !bytes.Equal(data, v) { - t.Fatalf("got data = %v, want = %v", v, data) - } - - // Now that we drained the queue, check that functions fail with the - // right error code. - ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) - var buf bytes.Buffer - { - _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) - if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" { - t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d) - } - } -} - -func TestReusePort(t *testing.T) { - // This test ensures that ports are immediately available for reuse - // after Close on the endpoints using them returns. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // First case, just an endpoint that was bound. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - c.EP.Close() - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - c.EP.Close() - - // Second case, an endpoint that was bound and is connecting.. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - { - err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) - } - } - c.EP.Close() - - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - c.EP.Close() - - // Third case, an endpoint that was bound and is listening. - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - c.EP.Close() - - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - c.EP.SocketOptions().SetReuseAddress(true) - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } -} - -func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { - t.Helper() - - s := ep.SocketOptions().GetReceiveBufferSize() - if int(s) != v { - t.Fatalf("got receive buffer size = %d, want = %d", s, v) - } -} - -func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { - t.Helper() - - if s := ep.SocketOptions().GetSendBufferSize(); int(s) != v { - t.Fatalf("got send buffer size = %d, want = %d", s, v) - } -} - -func TestDefaultBufferSizes(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer func() { - if ep != nil { - ep.Close() - } - }() - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) - - // Change the default send buffer size. - { - opt := tcpip.TCPSendBufferSizeRangeOption{ - Min: 1, - Default: tcp.DefaultSendBufferSize * 2, - Max: tcp.DefaultSendBufferSize * 20, - } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) - - // Change the default receive buffer size. - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize * 3, - Max: tcp.DefaultReceiveBufferSize * 30, - } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - ep.Close() - ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3) -} - -func TestMinMaxBufferSizes(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer ep.Close() - - // Change the min/max values for send/receive - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - { - opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - // Set values below the min/2. - ep.SocketOptions().SetReceiveBufferSize(99, true) - checkRecvBufferSize(t, ep, 200) - - ep.SocketOptions().SetSendBufferSize(149, true) - - checkSendBufferSize(t, ep, 300) - - // Set values above the max. - ep.SocketOptions().SetReceiveBufferSize(1+tcp.DefaultReceiveBufferSize*20, true) - // Values above max are capped at max and then doubled. - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) - - ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true) - // Values above max are capped at max and then doubled. - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) -} - -func TestBindToDeviceOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}}) - - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer ep.Close() - - if err := s.CreateNIC(321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %s", err) - } - - // nicIDPtr is used instead of taking the address of NICID literals, which is - // a compiler error. - nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { - return &s - } - - testActions := []struct { - name string - setBindToDevice *tcpip.NICID - setBindToDeviceError tcpip.Error - getBindToDevice int32 - }{ - {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, - {"BindToExistent", nicIDPtr(321), nil, 321}, - {"UnbindToDevice", nicIDPtr(0), nil, 0}, - } - for _, testAction := range testActions { - t.Run(testAction.name, func(t *testing.T) { - if testAction.setBindToDevice != nil { - bindToDevice := int32(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) - } - } - bindToDevice := ep.SocketOptions().GetBindToDevice() - if bindToDevice != testAction.getBindToDevice { - t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) - } - }) - } -} - -func makeStack() (*stack.Stack, tcpip.Error) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocol, - ipv6.NewProtocol, - }, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - id := loopback.New() - if testing.Verbose() { - id = sniffer.New(id) - } - - if err := s.CreateNIC(1, id); err != nil { - return nil, err - } - - for _, ct := range []struct { - number tcpip.NetworkProtocolNumber - address tcpip.Address - }{ - {ipv4.ProtocolNumber, context.StackAddr}, - {ipv6.ProtocolNumber, context.StackV6Addr}, - } { - if err := s.AddAddress(1, ct.number, ct.address); err != nil { - return nil, err - } - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return s, nil -} - -func TestSelfConnect(t *testing.T) { - // This test ensures that intentional self-connects work. In particular, - // it checks that if an endpoint binds to say 127.0.0.1:1000 then - // connects to 127.0.0.1:1000, then it will be connected to itself, and - // is able to send and receive data through the same endpoint. - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Register for notification, then start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.WritableEvents) - defer wq.EventUnregister(&waitEntry) - - { - err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) - if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { - t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) - } - } - - <-notifyCh - if err := ep.LastError(); err != nil { - t.Fatalf("Connect failed: %s", err) - } - - // Write something. - data := []byte{1, 2, 3} - var r bytes.Reader - r.Reset(data) - if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Read back what was written. - wq.EventUnregister(&waitEntry) - wq.EventRegister(&waitEntry, waiter.ReadableEvents) - ept := endpointTester{ep} - rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second) - - if !bytes.Equal(data, rd) { - t.Fatalf("got data = %v, want = %v", rd, data) - } -} - -func TestConnectAvoidsBoundPorts(t *testing.T) { - addressTypes := func(t *testing.T, network string) []string { - switch network { - case "ipv4": - return []string{"v4"} - case "ipv6": - return []string{"v6"} - case "dual": - return []string{"v6", "mapped"} - default: - t.Fatalf("unknown network: '%s'", network) - } - - panic("unreachable") - } - - address := func(t *testing.T, addressType string, isAny bool) tcpip.Address { - switch addressType { - case "v4": - if isAny { - return "" - } - return context.StackAddr - case "v6": - if isAny { - return "" - } - return context.StackV6Addr - case "mapped": - if isAny { - return context.V4MappedWildcardAddr - } - return context.StackV4MappedAddr - default: - t.Fatalf("unknown address type: '%s'", addressType) - } - - panic("unreachable") - } - // This test ensures that Endpoint.Connect doesn't select already-bound ports. - networks := []string{"ipv4", "ipv6", "dual"} - for _, exhaustedNetwork := range networks { - t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) { - for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) { - t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) { - for _, isAny := range []bool{false, true} { - t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) { - for _, candidateNetwork := range networks { - t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) { - for _, candidateAddressType := range addressTypes(t, candidateNetwork) { - t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) { - s, err := makeStack() - if err != nil { - t.Fatal(err) - } - - var wq waiter.Queue - var eps []tcpip.Endpoint - defer func() { - for _, ep := range eps { - ep.Close() - } - }() - makeEP := func(network string) tcpip.Endpoint { - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch network { - case "ipv4": - networkProtocolNumber = ipv4.ProtocolNumber - case "ipv6", "dual": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatalf("unknown network: '%s'", network) - } - ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - eps = append(eps, ep) - switch network { - case "ipv4": - case "ipv6": - ep.SocketOptions().SetV6Only(true) - case "dual": - ep.SocketOptions().SetV6Only(false) - default: - t.Fatalf("unknown network: '%s'", network) - } - return ep - } - - var v4reserved, v6reserved bool - switch exhaustedAddressType { - case "v4", "mapped": - v4reserved = true - case "v6": - v6reserved = true - // Dual stack sockets bound to v6 any reserve on v4 as - // well. - if isAny { - switch exhaustedNetwork { - case "ipv6": - case "dual": - v4reserved = true - default: - t.Fatalf("unknown address type: '%s'", exhaustedNetwork) - } - } - default: - t.Fatalf("unknown address type: '%s'", exhaustedAddressType) - } - var collides bool - switch candidateAddressType { - case "v4", "mapped": - collides = v4reserved - case "v6": - collides = v6reserved - default: - t.Fatalf("unknown address type: '%s'", candidateAddressType) - } - - const ( - start = 16000 - end = 16050 - ) - if err := s.SetPortRange(start, end); err != nil { - t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err) - } - for i := start; i <= end; i++ { - if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { - t.Fatalf("Bind(%d) failed: %s", i, err) - } - } - var want tcpip.Error = &tcpip.ErrConnectStarted{} - if collides { - want = &tcpip.ErrNoPortAvailable{} - } - if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { - t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want) - } - }) - } - }) - } - }) - } - }) - } - }) - } -} - -func TestPathMTUDiscovery(t *testing.T) { - // This test verifies the stack retransmits packets after it receives an - // ICMP packet indicating that the path MTU has been exceeded. - c := context.New(t, 1500) - defer c.Cleanup() - - // Create new connection with MSS of 1460. - const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{ - header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), - }) - - // Send 3200 bytes of data. - const writeSize = 3200 - data := make([]byte, writeSize) - for i := range data { - data[i] = byte(i) - } - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { - var ret []byte - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i, size := range sizes { - p := c.GetPacket() - if i == which { - ret = p - } - checker.IPv4(t, p, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(seqNum), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - seqNum += uint32(size) - } - return ret - } - - // Receive three packets. - sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload} - first := receivePackets(c, sizes, 0, uint32(c.IRS)+1) - - // Send "packet too big" messages back to netstack. - const newMTU = 1200 - const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize - mtu := []byte{0, 0, newMTU / 256, newMTU % 256} - c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU) - - // See retransmitted packets. None exceeding the new max. - sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload} - receivePackets(c, sizes, -1, uint32(c.IRS)+1) -} - -func TestTCPEndpointProbe(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - invoked := make(chan struct{}) - c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { - // Validate that the endpoint ID is what we expect. - // - // We don't do an extensive validation of every field but a - // basic sanity test. - if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want { - t.Fatalf("got LocalAddress: %q, want: %q", got, want) - } - if got, want := state.ID.LocalPort, c.Port; got != want { - t.Fatalf("got LocalPort: %d, want: %d", got, want) - } - if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want { - t.Fatalf("got RemoteAddress: %q, want: %q", got, want) - } - if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want { - t.Fatalf("got RemotePort: %d, want: %d", got, want) - } - - invoked <- struct{}{} - }) - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - data := []byte{1, 2, 3} - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - - select { - case <-invoked: - case <-time.After(100 * time.Millisecond): - t.Fatalf("TCP Probe function was not called") - } -} - -func TestStackSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", &tcpip.ErrNoSuchFile{}}, - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - var oldCC tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) - } - - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err) - } - - var cc tcpip.CongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) - } - - got, want := cc, oldCC - // If SetTransportProtocolOption is expected to succeed - // then the returned value for congestion control should - // match the one specified in the - // SetTransportProtocolOption call above, else it should - // be what it was before the call to - // SetTransportProtocolOption. - if tc.err == nil { - want = tc.cc - } - if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) - } - }) - } -} - -func TestStackAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - // Query permitted congestion control algorithms. - var aCC tcpip.TCPAvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) - } - if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want) - } -} - -func TestStackSetAvailableCongestionControl(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - s := c.Stack() - - // Setting AvailableCongestionControlOption should fail. - aCC := tcpip.TCPAvailableCongestionControlOption("xyz") - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC) - } - - // Verify that we still get the expected list of congestion control options. - var cc tcpip.TCPAvailableCongestionControlOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err) - } - if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want) - } -} - -func TestEndpointSetCongestionControl(t *testing.T) { - testCases := []struct { - cc tcpip.CongestionControlOption - err tcpip.Error - }{ - {"reno", nil}, - {"cubic", nil}, - {"blahblah", &tcpip.ErrNoSuchFile{}}, - } - - for _, connected := range []bool{false, true} { - for _, tc := range testCases { - t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) { - c := context.New(t, 1500) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - var oldCC tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&oldCC); err != nil { - t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err) - } - - if connected { - c.Connect(context.TestInitialSequenceNumber, 32768 /* rcvWnd */, nil) - } - - if err := c.EP.SetSockOpt(&tc.cc); err != tc.err { - t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err) - } - - var cc tcpip.CongestionControlOption - if err := c.EP.GetSockOpt(&cc); err != nil { - t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err) - } - - got, want := cc, oldCC - // If SetSockOpt is expected to succeed then the - // returned value for congestion control should match - // the one specified in the SetSockOpt above, else it - // should be what it was before the call to SetSockOpt. - if tc.err == nil { - want = tc.cc - } - if got != want { - t.Fatalf("got congestion control = %+v, want = %+v", got, want) - } - }) - } - } -} - -func enableCUBIC(t *testing.T, c *context.Context) { - t.Helper() - opt := tcpip.CongestionControlOption("cubic") - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err) - } -} - -func TestKeepalive(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - const keepAliveIdle = 100 * time.Millisecond - const keepAliveInterval = 3 * time.Second - keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle) - if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err) - } - keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval) - if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err) - } - c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) - if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { - t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) - } - c.EP.SocketOptions().SetKeepAlive(true) - - // 5 unacked keepalives are sent. ACK each one, and check that the - // connection stays alive after 5. - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for i := 0; i < 10; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Acknowledge the keepalive. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS, - RcvWnd: 30000, - }) - } - - // Check that the connection is still alive. - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Send some data and wait before ACKing it. Keepalives should be disabled - // during this period. - view := make([]byte, 3) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Wait for the packet to be retransmitted. Verify that no keepalives - // were sent. - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), - ), - ) - c.CheckNoPacket("Keepalive packet received while unACKed data is pending") - - next += uint32(len(view)) - - // Send ACK. Keepalives should start sending again. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - // Now receive 5 keepalives, but don't ACK them. The connection - // should be reset after 5. - for i := 0; i < 5; i++ { - b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next-1), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Sleep for a litte over the KeepAlive interval to make sure - // the timer has time to fire after the last ACK and close the - // close the socket. - time.Sleep(keepAliveInterval + keepAliveInterval/2) - - // The connection should be terminated after 5 unacked keepalives. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)), - ) - - if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) - } - - ept.CheckReadError(t, &tcpip.ErrTimeout{}) - - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { - t.Helper() - // Send a SYN request. - irs = seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss = seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - - if synCookieInUse { - // When cookies are in use window scaling is disabled. - tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptions(), - })) - } - - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - return irs, iss -} - -func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { - t.Helper() - // Send a SYN request. - irs = seqnum.Value(context.TestInitialSequenceNumber) - c.SendV6Packet(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetV6Packet() - tcp := header.TCP(header.IPv6(b).Payload()) - iss = seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(srcPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - - if synCookieInUse { - // When cookies are in use window scaling is disabled. - tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{ - WS: -1, - MSS: c.MSSWithoutOptionsV6(), - })) - } - - checker.IPv6(t, b, checker.TCP(tcpCheckers...)) - - // Send ACK. - c.SendV6Packet(nil, &context.Headers{ - SrcPort: srcPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - return irs, iss -} - -// TestListenBacklogFull tests that netstack does not complete handshakes if the -// listen backlog for the endpoint is full. -func TestListenBacklogFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - // Start listening. - listenBacklog := 10 - if err := c.EP.Listen(listenBacklog); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - lastPortOffset := uint16(0) - for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { - executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) - } - - time.Sleep(50 * time.Millisecond) - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + lastPortOffset, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(context.TestInitialSequenceNumber), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - for i := 0; i < listenBacklog; i++ { - _, _, err = c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept(nil) - if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } - - // Now a new handshake must succeed. - executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) - - newEP, _, err := c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - var r strings.Reader - r.Reset(data) - newEP.Write(&r, tcpip.WriteOptions{}) - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - if string(tcp.Payload()) != data { - t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) - } -} - -// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a -// non unicast IPv4 address are not accepted. -func TestListenNoAcceptNonUnicastV4(t *testing.T) { - multicastAddr := tcpiptestutil.MustParse4("224.0.1.2") - otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3") - subnet := context.StackAddrWithPrefix.Subnet() - subnetBroadcastAddr := subnet.Broadcast() - - tests := []struct { - name string - srcAddr tcpip.Address - dstAddr tcpip.Address - }{ - { - name: "SourceUnspecified", - srcAddr: header.IPv4Any, - dstAddr: context.StackAddr, - }, - { - name: "SourceBroadcast", - srcAddr: header.IPv4Broadcast, - dstAddr: context.StackAddr, - }, - { - name: "SourceOurMulticast", - srcAddr: multicastAddr, - dstAddr: context.StackAddr, - }, - { - name: "SourceOtherMulticast", - srcAddr: otherMulticastAddr, - dstAddr: context.StackAddr, - }, - { - name: "DestUnspecified", - srcAddr: context.TestAddr, - dstAddr: header.IPv4Any, - }, - { - name: "DestBroadcast", - srcAddr: context.TestAddr, - dstAddr: header.IPv4Broadcast, - }, - { - name: "DestOurMulticast", - srcAddr: context.TestAddr, - dstAddr: multicastAddr, - }, - { - name: "DestOtherMulticast", - srcAddr: context.TestAddr, - dstAddr: otherMulticastAddr, - }, - { - name: "SrcSubnetBroadcast", - srcAddr: subnetBroadcastAddr, - dstAddr: context.StackAddr, - }, - { - name: "DestSubnetBroadcast", - srcAddr: context.TestAddr, - dstAddr: subnetBroadcastAddr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil { - t.Fatalf("JoinGroup failed: %s", err) - } - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, test.srcAddr, test.dstAddr) - c.CheckNoPacket("Should not have received a response") - - // Handle normal packet. - c.SendPacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, context.TestAddr, context.StackAddr) - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1))) - }) - } -} - -// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a -// non unicast IPv6 address are not accepted. -func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpiptestutil.MustParse6("ff0e::101") - otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102") - - tests := []struct { - name string - srcAddr tcpip.Address - dstAddr tcpip.Address - }{ - { - "SourceUnspecified", - header.IPv6Any, - context.StackV6Addr, - }, - { - "SourceAllNodes", - header.IPv6AllNodesMulticastAddress, - context.StackV6Addr, - }, - { - "SourceOurMulticast", - multicastAddr, - context.StackV6Addr, - }, - { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackV6Addr, - }, - { - "DestUnspecified", - context.TestV6Addr, - header.IPv6Any, - }, - { - "DestAllNodes", - context.TestV6Addr, - header.IPv6AllNodesMulticastAddress, - }, - { - "DestOurMulticast", - context.TestV6Addr, - multicastAddr, - }, - { - "DestOtherMulticast", - context.TestV6Addr, - otherMulticastAddr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateV6Endpoint(true) - - if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { - t.Fatalf("JoinGroup failed: %s", err) - } - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendV6PacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, test.srcAddr, test.dstAddr) - c.CheckNoPacket("Should not have received a response") - - // Handle normal packet. - c.SendV6PacketWithAddrs(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }, context.TestV6Addr, context.StackV6Addr) - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1))) - }) - } -} - -func TestListenSynRcvdQueueFull(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send two SYN's the first one should get a SYN-ACK, the - // second one should not get any response and is dropped as - // the accept queue is full. - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcp.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now complete the previous connection. - // Send ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Verify if that is delivered to the accept queue. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - <-ch - - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(889), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Try to accept the connections in the backlog. - newEP, _, err := c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - var r strings.Reader - r.Reset(data) - newEP.Write(&r, tcpip.WriteOptions{}) - pkt := c.GetPacket() - tcp = header.IPv4(pkt).Payload() - if string(tcp.Payload()) != data { - t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) - } -} - -func TestListenBacklogFullSynCookieInUse(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Test for SynCookies usage after filling up the backlog. - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - executeHandshake(t, c, context.TestPort, false) - - // Wait for this to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - - // Send a SYN request. - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - // pick a different src port for new SYN. - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - // The Syn should be dropped as the endpoint's backlog is full. - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Verify that there is only one acceptable connection at this point. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - _, _, err = c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept(nil) - if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - select { - case <-ch: - t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) - case <-time.After(1 * time.Second): - } - } -} - -func TestSYNRetransmit(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send the same SYN packet multiple times. We should still get a valid SYN-ACK - // reply. - irs := seqnum.Value(context.TestInitialSequenceNumber) - for i := 0; i < 5; i++ { - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - } - - // Receive the SYN-ACK reply. - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...)) -} - -func TestSynRcvdBadSeqNumber(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - // Bind to wildcard. - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - // Start listening. - if err := c.EP.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state - irs := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: irs, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - iss := seqnum.Value(tcpHdr.SequenceNumber()) - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs) + 1), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now send a packet with an out-of-window sequence number - largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: largeSeqnum, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - // Should receive an ACK with the expected SEQ number - b = c.GetPacket() - tcpCheckers = []checker.TransportChecker{ - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPAckNum(uint32(irs) + 1), - checker.TCPSeqNum(uint32(iss + 1)), - } - checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - - // Now that the socket replied appropriately with the ACK, - // complete the connection to test that the large SEQ num - // did not change the state from SYN-RCVD. - - // Send ACK to move to ESTABLISHED state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - RcvWnd: 30000, - }) - - newEP, _, err := c.EP.Accept(nil) - switch err.(type) { - case nil, *tcpip.ErrWouldBlock: - default: - t.Fatalf("Accept failed: %s", err) - } - - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now verify that the TCP socket is usable and in a connected state. - data := "Don't panic" - var r strings.Reader - r.Reset(data) - if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - pkt := c.GetPacket() - tcpHdr = header.IPv4(pkt).Payload() - if string(tcpHdr.Payload()) != data { - t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) - } -} - -func TestPassiveConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - stats := c.Stack().Stats() - want := stats.TCP.PassiveConnectionOpenings.Value() + 1 - - srcPort := uint16(context.TestPort) - executeHandshake(t, c, srcPort+1, false) - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { - t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want) - } -} - -func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - stats := c.Stack().Stats() - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - c.EP = ep - if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if err := c.EP.Listen(1); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - srcPort := uint16(context.TestPort) - // Now attempt a handshakes it will fill up the accept backlog. - executeHandshake(t, c, srcPort, false) - - // Give time for the final ACK to be processed as otherwise the next handshake could - // get accepted before the previous one based on goroutine scheduling. - time.Sleep(50 * time.Millisecond) - - want := stats.TCP.ListenOverflowSynDrop.Value() + 1 - - // Now we will send one more SYN and this one should get dropped - // Send a SYN request. - c.SendPacket(nil, &context.Headers{ - SrcPort: srcPort + 2, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(context.TestInitialSequenceNumber), - RcvWnd: 30000, - }) - - time.Sleep(50 * time.Millisecond) - if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want) - } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want) - } - - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } -} - -func TestEndpointBindListenAcceptState(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - ept := endpointTester{ep} - ept.CheckReadError(t, &tcpip.ErrNotConnected{}) - if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { - t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - aep, _, err := ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - aep, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - { - err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" { - t.Errorf("Connect(...) mismatch (-want +got):\n%s", d) - } - } - // Listening endpoint remains in listen state. - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - - ep.Close() - // Give worker goroutines time to receive the close notification. - time.Sleep(1 * time.Second) - if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - // Accepted endpoint remains open when the listen endpoint is closed. - if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { - t.Errorf("unexpected endpoint state: want %s, got %s", want, got) - } - -} - -// This test verifies that the auto tuning does not grow the receive buffer if -// the application is not reading the data actively. -func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { - const mtu = 1500 - const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - - c := context.New(t, mtu) - defer c.Cleanup() - - stk := c.Stack() - // Set lower limits for auto-tuning tests. This is required because the - // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. - const receiveBufferSize = 80 << 10 // 80KB. - const maxReceiveBufferSize = receiveBufferSize * 10 - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - // Enable auto-tuning. - { - opt := tcpip.TCPModerateReceiveBufferOption(true) - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - // Change the expected window scale to match the value needed for the - // maximum buffer size defined above. - c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - - // NOTE: The timestamp values in the sent packets are meaningless to the - // peer so we just increment the timestamp value by 1 every batch as we - // are not really using them for anything. Send a single byte to verify - // the advertised window. - tsVal := rawEP.TSVal + 1 - - // Introduce a 25ms latency by delaying the first byte. - latency := 25 * time.Millisecond - time.Sleep(latency) - // Send an initial payload with atleast segment overhead size. The receive - // window would not grow for smaller segments. - rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal) - - pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) - rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() - - time.Sleep(25 * time.Millisecond) - - // Allocate a large enough payload for the test. - payloadSize := receiveBufferSize * 2 - b := make([]byte, payloadSize) - - worker := (c.EP).(interface { - StopWork() - ResumeWork() - }) - tsVal++ - - // Stop the worker goroutine. - worker.StopWork() - start := 0 - end := payloadSize / 2 - packetsSent := 0 - for ; start < end; start += mss { - packetEnd := start + mss - if start+mss > end { - packetEnd = end - } - rawEP.SendPacketWithTS(b[start:packetEnd], tsVal) - packetsSent++ - } - - // Resume the worker so that it only sees the packets once all of them - // are waiting to be read. - worker.ResumeWork() - - // Since we sent almost the full receive buffer worth of data (some may have - // been dropped due to segment overheads), we should get a zero window back. - pkt = c.GetPacket() - tcpHdr := header.TCP(header.IPv4(pkt).Payload()) - gotRcvWnd := tcpHdr.WindowSize() - wantAckNum := tcpHdr.AckNumber() - if got, want := int(gotRcvWnd), 0; got != want { - t.Fatalf("got rcvWnd: %d, want: %d", got, want) - } - - time.Sleep(25 * time.Millisecond) - // Verify that sending more data when receiveBuffer is exhausted. - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - - // Now read all the data from the endpoint and verify that advertised - // window increases to the full available buffer size. - for { - _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - break - } - } - - // Verify that we receive a non-zero window update ACK. When running - // under thread santizer this test can end up sending more than 1 - // ack, 1 for the non-zero window - p := c.GetPacket() - checker.IPv4(t, p, checker.TCP( - checker.TCPAckNum(wantAckNum), - func(t *testing.T, h header.Transport) { - tcp, ok := h.(header.TCP) - if !ok { - return - } - // We use 10% here as the error margin upwards as the initial window we - // got was afer 1 segment was already in the receive buffer queue. - tolerance := 1.1 - if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) { - t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance)) - } - }, - )) -} - -// This test verifies that the advertised window is auto-tuned up as the -// application is reading the data that is being received. -func TestReceiveBufferAutoTuning(t *testing.T) { - const mtu = 1500 - const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize - - c := context.New(t, mtu) - defer c.Cleanup() - - // Enable Auto-tuning. - stk := c.Stack() - // Disable out of window rate limiting for this test by setting it to 0 as we - // use out of window ACKs to measure the advertised window. - var tcpInvalidRateLimit stack.TCPInvalidRateLimitOption - if err := stk.SetOption(tcpInvalidRateLimit); err != nil { - t.Fatalf("e.stack.SetOption(%#v) = %s", tcpInvalidRateLimit, err) - } - - const receiveBufferSize = 80 << 10 // 80KB. - const maxReceiveBufferSize = receiveBufferSize * 10 - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - // Enable auto-tuning. - { - opt := tcpip.TCPModerateReceiveBufferOption(true) - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - // Change the expected window scale to match the value needed for the - // maximum buffer size used by stack. - c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) - - rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - tsVal := rawEP.TSVal - rawEP.NextSeqNum-- - rawEP.SendPacketWithTS(nil, tsVal) - rawEP.NextSeqNum++ - pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) - curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale - scaleRcvWnd := func(rcvWnd int) uint16 { - return uint16(rcvWnd >> c.WindowScale) - } - // Allocate a large array to send to the endpoint. - b := make([]byte, receiveBufferSize*48) - - // In every iteration we will send double the number of bytes sent in - // the previous iteration and read the same from the app. The received - // window should grow by at least 2x of bytes read by the app in every - // RTT. - offset := 0 - payloadSize := receiveBufferSize / 8 - worker := (c.EP).(interface { - StopWork() - ResumeWork() - }) - latency := 1 * time.Millisecond - for i := 0; i < 5; i++ { - tsVal++ - - // Stop the worker goroutine. - worker.StopWork() - start := offset - end := offset + payloadSize - totalSent := 0 - packetsSent := 0 - for ; start < end; start += mss { - rawEP.SendPacketWithTS(b[start:start+mss], tsVal) - totalSent += mss - packetsSent++ - } - - // Resume it so that it only sees the packets once all of them - // are waiting to be read. - worker.ResumeWork() - - // Give 1ms for the worker to process the packets. - time.Sleep(1 * time.Millisecond) - - lastACK := c.GetPacket() - // Discard any intermediate ACKs and only check the last ACK we get in a - // short time period of few ms. - for { - time.Sleep(1 * time.Millisecond) - pkt := c.GetPacketNonBlocking() - if pkt == nil { - break - } - lastACK = pkt - } - if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want { - t.Fatalf("advertised window got: %d, want <= %d", got, want) - } - - // Now read all the data from the endpoint and invoke the - // moderation API to allow for receive buffer auto-tuning - // to happen before we measure the new window. - totalCopied := 0 - for { - res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - break - } - totalCopied += res.Count - } - - // Invoke the moderation API. This is required for auto-tuning - // to happen. This method is normally expected to be invoked - // from a higher layer than tcpip.Endpoint. So we simulate - // copying to userspace by invoking it explicitly here. - c.EP.ModerateRecvBuf(totalCopied) - - // Now send a keep-alive packet to trigger an ACK so that we can - // measure the new window. - rawEP.NextSeqNum-- - rawEP.SendPacketWithTS(nil, tsVal) - rawEP.NextSeqNum++ - - if i == 0 { - // In the first iteration the receiver based RTT is not - // yet known as a result the moderation code should not - // increase the advertised window. - rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) - } else { - // Read loop above could generate an ACK if the window had dropped to - // zero and then read had opened it up. - lastACK := c.GetPacket() - // Discard any intermediate ACKs and only check the last ACK we get in a - // short time period of few ms. - for { - time.Sleep(1 * time.Millisecond) - pkt := c.GetPacketNonBlocking() - if pkt == nil { - break - } - lastACK = pkt - } - curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale - // If thew new current window is close maxReceiveBufferSize then terminate - // the loop. This can happen before all iterations are done due to timing - // differences when running the test. - if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 { - break - } - // Increase the latency after first two iterations to - // establish a low RTT value in the receiver since it - // only tracks the lowest value. This ensures that when - // ModerateRcvBuf is called the elapsed time is always > - // rtt. Without this the test is flaky due to delays due - // to scheduling/wakeup etc. - latency += 50 * time.Millisecond - } - time.Sleep(latency) - offset += payloadSize - payloadSize *= 2 - } - // Check that at the end of our iterations the receive window grew close to the maximum - // permissible size of maxReceiveBufferSize/2 - if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want { - t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want) - } - -} - -func TestDelayEnabled(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - checkDelayOption(t, c, false, false) // Delay is disabled by default. - - for _, delayEnabled := range []bool{false, true} { - t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - opt := tcpip.TCPDelayEnabled(delayEnabled) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err) - } - checkDelayOption(t, c, opt, delayEnabled) - }) - } -} - -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) { - t.Helper() - - var gotDelayEnabled tcpip.TCPDelayEnabled - if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { - t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) - } - if gotDelayEnabled != wantDelayEnabled { - t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) - } - - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) - if err != nil { - t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err) - } - gotDelayOption := ep.SocketOptions().GetDelayOption() - if gotDelayOption != wantDelayOption { - t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) - } -} - -func TestTCPLingerTimeout(t *testing.T) { - c := context.New(t, 1500 /* mtu */) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - testCases := []struct { - name string - tcpLingerTimeout time.Duration - want time.Duration - }{ - {"NegativeLingerTimeout", -123123, -1}, - // Zero is treated same as the stack's default TCP_LINGER2 timeout. - {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout}, - {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, - // Values > stack's TCPLingerTimeout are capped to the stack's - // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) - {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout) - if err := c.EP.SetSockOpt(&v); err != nil { - t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err) - } - - v = 0 - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockOpt(&%T) = %s", v, err) - } - if got, want := time.Duration(v), tc.want; got != want { - t.Fatalf("got linger timeout = %s, want = %s", got, want) - } - }) - } -} - -func TestTCPTimeWaitRSTIgnored(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Now send a RST and this should be ignored and not - // generate an ACK. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagRst, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - }) - - c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) - - // Out of order ACK should generate an immediate ACK in - // TIME_WAIT. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 3, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) -} - -func TestTCPTimeWaitOutOfOrder(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Out of order ACK should generate an immediate ACK in - // TIME_WAIT. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 3, - }) - - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) -} - -func TestTCPTimeWaitNewSyn(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Send a SYN request w/ sequence number lower than - // the highest sequence number sent. We just reuse - // the same number. - iss = seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) - - // drain any older notifications from the notification channel before attempting - // 2nd connection. - select { - case <-ch: - default: - } - - // Send a SYN request w/ sequence number higher than - // the highest sequence number sent. - iss = iss.Add(3) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b = c.GetPacket() - tcpHdr = header.IPv4(b).Payload() - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } -} - -func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed - // after 5 seconds in TIME_WAIT state. - tcpTimeWaitTimeout := 5 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) - } - - want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - c.EP.Close() - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+1), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - time.Sleep(2 * time.Second) - - // Now send a duplicate FIN. This should cause the TIME_WAIT to extend - // by another 5 seconds and also send us a duplicate ACK as it should - // indicate that the final ACK was potentially lost. - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+2)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Sleep for 4 seconds so at this point we are 1 second past the - // original tcpLingerTimeout of 5 seconds. - time.Sleep(4 * time.Second) - - // Send an ACK and it should not generate any packet as the socket - // should still be in TIME_WAIT for another another 5 seconds due - // to the duplicate FIN we sent earlier. - *ackHeaders = *finHeaders - ackHeaders.SeqNum = ackHeaders.SeqNum + 1 - ackHeaders.Flags = header.TCPFlagAck - c.SendPacket(nil, ackHeaders) - - c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) - // Now sleep for another 2 seconds so that we are past the - // extended TIME_WAIT of 7 seconds (2 + 5). - time.Sleep(2 * time.Second) - - // Resend the same ACK. - c.SendPacket(nil, ackHeaders) - - // Receive the RST that should be generated as there is no valid - // endpoint. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(ackHeaders.AckNum)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) - - if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want) - } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) - } -} - -func TestTCPCloseWithData(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed - // after 5 seconds in TIME_WAIT state. - tcpTimeWaitTimeout := 5 * time.Second - opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) - } - - wq := &waiter.Queue{} - ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %s", err) - } - - // Send a SYN request. - iss := seqnum.Value(context.TestInitialSequenceNumber) - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - }) - - // Receive the SYN-ACK reply. - b := c.GetPacket() - tcpHdr := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - ackHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - RcvWnd: 30000, - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - - // Now trigger a passive close by sending a FIN. - finHeaders := &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck | header.TCPFlagFin, - SeqNum: iss + 1, - AckNum: c.IRS + 2, - RcvWnd: 30000, - } - - c.SendPacket(nil, finHeaders) - - // Get the ACK to the FIN we just sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(iss)+2), - checker.TCPFlags(header.TCPFlagAck))) - - // Now write a few bytes and then close the endpoint. - data := []byte{1, 2, 3} - - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - // Check that data is received. - b = c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { - t.Errorf("got data = %x, want = %x", p, data) - } - - c.EP.Close() - // Check the FIN. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))), - checker.TCPAckNum(uint32(iss+2)), - checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - - // First send a partial ACK. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Now send a full ACK. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Now ACK the FIN. - ackHeaders.AckNum++ - c.SendPacket(nil, ackHeaders) - - // Now send an ACK and we should get a RST back as the endpoint should - // be in CLOSED state. - ackHeaders = &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 2, - AckNum: c.IRS + 1 + seqnum.Value(len(data)), - RcvWnd: 30000, - } - c.SendPacket(nil, ackHeaders) - - // Check the RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(ackHeaders.AckNum)), - checker.TCPAckNum(0), - checker.TCPFlags(header.TCPFlagRst))) -} - -func TestTCPUserTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventHUp) - defer c.WQ.EventUnregister(&waitEntry) - - origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() - - // Ensure that on the next retransmit timer fire, the user timeout has - // expired. - initRTO := 1 * time.Second - userTimeout := initRTO / 2 - v := tcpip.TCPUserTimeoutOption(userTimeout) - if err := c.EP.SetSockOpt(&v); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err) - } - - // Send some data and wait before ACKing it. - view := make([]byte, 3) - var r bytes.Reader - r.Reset(view) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %s", err) - } - - next := uint32(c.IRS) + 1 - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(len(view)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - // Wait for the retransmit timer to be fired and the user timeout to cause - // close of the connection. - select { - case <-notifyCh: - case <-time.After(2 * initRTO): - t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout) - } - - // No packet should be received as the connection should be silently - // closed due to timeout. - c.CheckNoPacket("unexpected packet received after userTimeout has expired") - - next += uint32(len(view)) - - // The connection should be terminated after userTimeout has expired. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: seqnum.Value(next), - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(next), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrTimeout{}) - - if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func TestKeepaliveWithUserTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */) - - origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() - - const keepAliveIdle = 100 * time.Millisecond - const keepAliveInterval = 3 * time.Second - keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle) - if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err) - } - keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval) - if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err) - } - if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { - t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) - } - c.EP.SocketOptions().SetKeepAlive(true) - - // Set userTimeout to be the duration to be 1 keepalive - // probes. Which means that after the first probe is sent - // the second one should cause the connection to be - // closed due to userTimeout being hit. - userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval) - if err := c.EP.SetSockOpt(&userTimeout); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err) - } - - // Check that the connection is still alive. - ept := endpointTester{c.EP} - ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) - - // Now receive 1 keepalives, but don't ACK it. - b := c.GetPacket() - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)), - checker.TCPAckNum(uint32(iss)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - - // Sleep for a litte over the KeepAlive interval to make sure - // the timer has time to fire after the last ACK and close the - // close the socket. - time.Sleep(keepAliveInterval + keepAliveInterval/2) - - // The connection should be closed with a timeout. - // Send an ACK to trigger a RST from the stack as the endpoint should - // be dead. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS + 1, - RcvWnd: 30000, - }) - - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS+1)), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), - ) - - ept.CheckReadError(t, &tcpip.ErrTimeout{}) - if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { - t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) - } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) - } -} - -func TestIncreaseWindowOnRead(t *testing.T) { - // This test ensures that the endpoint sends an ack, - // after read() when the window grows by more than 1 MSS. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const rcvBuf = 65535 * 10 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) - - // Write chunks of ~30000 bytes. It's important that two - // payloads make it equal or longer than MSS. - remain := rcvBuf * 2 - sent := 0 - data := make([]byte, defaultMTU/2) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(sent)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - pkt := c.GetPacket() - checker.IPv4(t, pkt, - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - // Break once the window drops below defaultMTU/2 - if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 { - break - } - } - - // We now have < 1 MSS in the buffer space. Read at least > 2 MSS - // worth of data as receive buffer space - w := tcpip.LimitedWriter{ - W: ioutil.Discard, - // defaultMTU is a good enough estimate for the MSS used for this - // connection. - N: defaultMTU * 2, - } - for w.N != 0 { - _, err := c.EP.Read(&w, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - } - - // After reading > MSS worth of data, we surely crossed MSS. See the ack: - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPWindow(uint16(0xffff)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestIncreaseWindowOnBufferResize(t *testing.T) { - // This test ensures that the endpoint sends an ack, - // after available recv buffer grows to more than 1 MSS. - c := context.New(t, defaultMTU) - defer c.Cleanup() - - const rcvBuf = 65535 * 10 - c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf) - - // Write chunks of ~30000 bytes. It's important that two - // payloads make it equal or longer than MSS. - remain := rcvBuf - sent := 0 - data := make([]byte, defaultMTU/2) - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - for remain > len(data) { - c.SendPacket(data, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss.Add(seqnum.Size(sent)), - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - }) - sent += len(data) - remain -= len(data) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPWindowLessThanEq(0xffff), - checker.TCPFlags(header.TCPFlagAck), - ), - ) - } - - // Increasing the buffer from should generate an ACK, - // since window grew from small value to larger equal MSS - c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*2, true) - checker.IPv4(t, c.GetPacket(), - checker.PayloadLen(header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+uint32(sent)), - checker.TCPWindow(uint16(0xffff)), - checker.TCPFlags(header.TCPFlagAck), - ), - ) -} - -func TestTCPDeferAccept(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - const tcpDeferAccept = 1 * time.Second - tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept) - if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - _, _, err := c.EP.Accept(nil) - if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { - t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) - } - - // Send data. This should result in an acceptable endpoint. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) - - // Give a bit of time for the socket to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept(nil) - if err != nil { - t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) - } - - aep.Close() - // Closing aep without reading the data should trigger a RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) -} - -func TestTCPDeferAcceptTimeout(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.Create(-1) - - if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatal("Bind failed:", err) - } - - if err := c.EP.Listen(10); err != nil { - t.Fatal("Listen failed:", err) - } - - const tcpDeferAccept = 1 * time.Second - tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept) - if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err) - } - - irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - - _, _, err := c.EP.Accept(nil) - if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { - t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) - } - - // Sleep for a little of the tcpDeferAccept timeout. - time.Sleep(tcpDeferAccept + 100*time.Millisecond) - - // On timeout expiry we should get a SYN-ACK retransmission. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), - checker.TCPAckNum(uint32(irs)+1))) - - // Send data. This should result in an acceptable endpoint. - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: context.StackPort, - Flags: header.TCPFlagAck, - SeqNum: irs + 1, - AckNum: iss + 1, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) - - // Give sometime for the endpoint to be delivered to the accept queue. - time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept(nil) - if err != nil { - t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) - } - - aep.Close() - // Closing aep without reading the data should trigger a RST. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.SrcPort(context.StackPort), - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), - checker.TCPSeqNum(uint32(iss+1)), - checker.TCPAckNum(uint32(irs+5)))) -} - -func TestResetDuringClose(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRecvBuf */) - // Send some data to make sure there is some unread - // data to trigger a reset on c.Close. - irs := c.IRS - iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1) - c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: iss, - AckNum: irs.Add(1), - RcvWnd: 30000, - }) - - // Receive ACK for the data we sent. - checker.IPv4(t, c.GetPacket(), checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(irs.Add(1))), - checker.TCPAckNum(uint32(iss)+4))) - - // Close in a separate goroutine so that we can trigger - // a race with the RST we send below. This should not - // panic due to the route being released depeding on - // whether Close() sends an active RST or the RST sent - // below is processed by the worker first. - var wg sync.WaitGroup - - wg.Add(1) - go func() { - defer wg.Done() - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, - DstPort: c.Port, - SeqNum: iss.Add(4), - AckNum: c.IRS.Add(5), - RcvWnd: 30000, - Flags: header.TCPFlagRst, - }) - }() - - wg.Add(1) - go func() { - defer wg.Done() - c.EP.Close() - }() - - wg.Wait() -} - -func TestStackTimeWaitReuse(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - s := c.Stack() - var twReuse tcpip.TCPTimeWaitReuseOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err) - } - if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want { - t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) - } -} - -func TestSetStackTimeWaitReuse(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - s := c.Stack() - testCases := []struct { - v int - err tcpip.Error - }{ - {int(tcpip.TCPTimeWaitReuseDisabled), nil}, - {int(tcpip.TCPTimeWaitReuseGlobal), nil}, - {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, - {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}}, - {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}}, - } - - for _, tc := range testCases { - opt := tcpip.TCPTimeWaitReuseOption(tc.v) - err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt) - if got, want := err, tc.err; got != want { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err) - } - if tc.err != nil { - continue - } - - var twReuse tcpip.TCPTimeWaitReuseOption - if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err) - } - - if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want { - t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want) - } - } -} - -// generateRandomPayload generates a random byte slice of the specified length -// causing a fatal test failure if it is unable to do so. -func generateRandomPayload(t *testing.T, n int) []byte { - t.Helper() - buf := make([]byte, n) - if _, err := rand.Read(buf); err != nil { - t.Fatalf("rand.Read(buf) failed: %s", err) - } - return buf -} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go deleted file mode 100644 index 1deb1fe4d..000000000 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ /dev/null @@ -1,311 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_test - -import ( - "bytes" - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" - "gvisor.dev/gvisor/pkg/waiter" -) - -// createConnectedWithTimestampOption creates and connects c.ep with the -// timestamp option enabled. -func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint { - return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1}) -} - -// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on -// an active connect and sets the TS Echo Reply fields correctly when the -// SYN-ACK also indicates support for the TS option and provides a TSVal. -func TestTimeStampEnabledConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read and validate that we have data to read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - // The following tests ensure that TS option once enabled behaves - // correctly as described in - // https://tools.ietf.org/html/rfc7323#section-4.3. - // - // We are not testing delayed ACKs here, but we do test out of order - // packet delivery and filling the sequence number hole created due to - // the out of order packet. - // - // The test also verifies that the sequence numbers and timestamps are - // as expected. - data := []byte{1, 2, 3} - - // First we increment tsVal by a small amount. - tsVal := rep.TSVal + 100 - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Next we send an out of order packet. - rep.NextSeqNum += 3 - tsVal += 200 - rep.SendPacketWithTS(data, tsVal) - - // The ACK should contain the original sequenceNumber and an older TS. - rep.NextSeqNum -= 6 - rep.VerifyACKWithTS(tsVal - 200) - - // Next we fill the hole and the returned ACK should contain the - // cumulative sequence number acking all data sent till now and have the - // latest timestamp sent below in its TSEcr field. - tsVal -= 100 - rep.SendPacketWithTS(data, tsVal) - rep.NextSeqNum += 3 - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal by a large value that doesn't result in a wrap around. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - // Increment tsVal again by a large value which should cause the - // timestamp value to wrap around. The returned ACK should contain the - // wrapped around timestamp in its tsEcr field and not the tsVal from - // the previous packet sent above. - tsVal += 0x7fffffff - rep.SendPacketWithTS(data, tsVal) - rep.VerifyACKWithTS(tsVal) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // There should be 5 views to read and each of them should - // contain the same data. - for i := 0; i < 5; i++ { - buf := make([]byte, len(data)) - w := tcpip.SliceWriter(buf) - result, err := c.EP.Read(&w, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: len(buf), - Total: len(buf), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("Read: unexpected result (-want +got):\n%s", diff) - } - if got, want := buf, data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } - } -} - -// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an -// active connect but if the SYN-ACK doesn't specify the TS option then -// timestamp option is not enabled and future packets do not contain a -// timestamp. -func TestTimeStampDisabledConnect(t *testing.T) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - c.CreateConnectedWithOptions(header.TCPSynOptions{}) -} - -func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - if cookieEnabled { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - - t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - tsVal := rand.Uint32() - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal}) - - // Now send some data and validate that timestamp is echoed correctly in the ACK. - data := []byte{1, 2, 3} - - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %s", err) - } - - // Check that data is received and that the timestamp option TSEcr field - // matches the expected value. - b := c.GetPacket() - checker.IPv4(t, b, - // Add 12 bytes for the timestamp option + 2 NOPs to align at 4 - // byte boundary. - checker.PayloadLen(len(data)+header.TCPMinimumSize+12), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - checker.TCPTimestampChecker(true, 0, tsVal+1), - ), - ) -} - -// TestTimeStampEnabledAccept tests that if the SYN on a passive connect -// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK -// and echoes the tsVal field of the original SYN in the tcEcr field of the -// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify -// that Timestamp option is enabled in both cases if requested in the original -// SYN. -func TestTimeStampEnabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that. - {false, 5, 0x4000}, - } - for _, tc := range testCases { - timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } -} - -func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { - c := context.New(t, defaultMTU) - defer c.Cleanup() - - if cookieEnabled { - opt := tcpip.TCPAlwaysUseSynCookies(true) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) - } - } - - t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) - c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) - - // Now send some data with the accepted connection endpoint and validate - // that no timestamp option is sent in the TCP segment. - data := []byte{1, 2, 3} - - var r bytes.Reader - r.Reset(data) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %s", err) - } - - // Check that data is received and that the timestamp option is disabled - // when SYN cookies are enabled/disabled. - b := c.GetPacket() - checker.IPv4(t, b, - checker.PayloadLen(len(data)+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(790), - checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - checker.TCPTimestampChecker(false, 0, 0), - ), - ) -} - -// TestTimeStampDisabledAccept tests that Timestamp option is not used when the -// peer doesn't advertise it and connection is established with Accept(). -func TestTimeStampDisabledAccept(t *testing.T) { - testCases := []struct { - cookieEnabled bool - wndScale int - wndSize uint16 - }{ - {true, -1, 0xffff}, // When cookie is used window scaling is disabled. - // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of - // that. - {false, 5, 0x4000}, - } - for _, tc := range testCases { - timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) - } -} - -func TestSendGreaterThanMTUWithOptions(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - createConnectedWithTimestampOption(c) - testBrokenUpWrite(t, c, maxPayload) -} - -func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) { - const maxPayload = 100 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) - defer c.Cleanup() - - rep := createConnectedWithTimestampOption(c) - - // Register for read. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - droppedPacketsStat := c.Stack().Stats().DroppedPackets - droppedPackets := droppedPacketsStat.Value() - data := []byte{1, 2, 3} - // Send a packet with no TCP options/timestamp. - rep.SendPacket(data, nil) - - select { - case <-ch: - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - - // Assert that DroppedPackets was not incremented. - if got, want := droppedPacketsStat.Value(), droppedPackets; got != want { - t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want) - } - - // Issue a read and we should data. - var buf bytes.Buffer - result, err := c.EP.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("Read: unexpected result (-want +got):\n%s", diff) - } - if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 { - t.Fatalf("Data is different: got: %v, want: %v", got, want) - } -} diff --git a/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go new file mode 100644 index 000000000..4cb82fcc9 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package tcp diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD deleted file mode 100644 index ce6a2c31d..000000000 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "context", - testonly = 1, - srcs = ["context.go"], - visibility = [ - "//visibility:public", - ], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/seqnum", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/tcp", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go deleted file mode 100644 index 53efecc5a..000000000 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ /dev/null @@ -1,1233 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package context provides a test context for use in tcp tests. It also -// provides helper methods to assert/check certain behaviours. -package context - -import ( - "bytes" - "context" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - // StackAddr is the IPv4 address assigned to the stack. - StackAddr = "\x0a\x00\x00\x01" - - // StackPort is used as the listening port in tests for passive - // connects. - StackPort = 1234 - - // TestAddr is the source address for packets sent to the stack via the - // link layer endpoint. - TestAddr = "\x0a\x00\x00\x02" - - // TestPort is the TCP port used for packets sent to the stack - // via the link layer endpoint. - TestPort = 4096 - - // StackV6Addr is the IPv6 address assigned to the stack. - StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - - // TestV6Addr is the source address for packets sent to the stack via - // the link layer endpoint. - TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - - // StackV4MappedAddr is StackAddr as a mapped v6 address. - StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr - - // TestV4MappedAddr is TestAddr as a mapped v6 address. - TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr - - // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0. - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" - - // TestInitialSequenceNumber is the initial sequence number sent in packets that - // are sent in response to a SYN or in the initial SYN sent to the stack. - TestInitialSequenceNumber = 789 -) - -// StackAddrWithPrefix is StackAddr with its associated prefix length. -var StackAddrWithPrefix = tcpip.AddressWithPrefix{ - Address: StackAddr, - PrefixLen: 24, -} - -// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length. -var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{ - Address: StackV6Addr, - PrefixLen: header.IIDOffsetInIPv6Address * 8, -} - -// Headers is used to represent the TCP header fields when building a -// new packet. -type Headers struct { - // SrcPort holds the src port value to be used in the packet. - SrcPort uint16 - - // DstPort holds the destination port value to be used in the packet. - DstPort uint16 - - // SeqNum is the value of the sequence number field in the TCP header. - SeqNum seqnum.Value - - // AckNum represents the acknowledgement number field in the TCP header. - AckNum seqnum.Value - - // Flags are the TCP flags in the TCP header. - Flags header.TCPFlags - - // RcvWnd is the window to be advertised in the ReceiveWindow field of - // the TCP header. - RcvWnd seqnum.Size - - // TCPOpts holds the options to be sent in the option field of the TCP - // header. - TCPOpts []byte -} - -// Options contains options for creating a new test context. -type Options struct { - // EnableV4 indicates whether IPv4 should be enabled. - EnableV4 bool - - // EnableV6 indicates whether IPv4 should be enabled. - EnableV6 bool - - // MTU indicates the maximum transmission unit on the link layer. - MTU uint32 -} - -// Context provides an initialized Network stack and a link layer endpoint -// for use in TCP tests. -type Context struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - // IRS holds the initial sequence number in the SYN sent by endpoint in - // case of an active connect or the sequence number sent by the endpoint - // in the SYN-ACK sent in response to a SYN when listening in passive - // mode. - IRS seqnum.Value - - // Port holds the port bound by EP below in case of an active connect or - // the listening port number in case of a passive connect. - Port uint16 - - // EP is the test endpoint in the stack owned by this context. This endpoint - // is used in various tests to either initiate an active connect or is used - // as a passive listening endpoint to accept inbound connections. - EP tcpip.Endpoint - - // Wq is the wait queue associated with EP and is used to block for events - // on EP. - WQ waiter.Queue - - // TimeStampEnabled is true if ep is connected with the timestamp option - // enabled. - TimeStampEnabled bool - - // WindowScale is the expected window scale in SYN packets sent by - // the stack. - WindowScale uint8 - - // RcvdWindowScale is the actual window scale sent by the stack in - // SYN/SYN-ACK. - RcvdWindowScale uint8 -} - -// New allocates and initializes a test context containing a new -// stack and a link-layer endpoint. -func New(t *testing.T, mtu uint32) *Context { - return NewWithOpts(t, Options{ - EnableV4: true, - EnableV6: true, - MTU: mtu, - }) -} - -// NewWithOpts allocates and initializes a test context containing a new -// stack and a link-layer endpoint with specific options. -func NewWithOpts(t *testing.T, opts Options) *Context { - if opts.MTU == 0 { - panic("MTU must be greater than 0") - } - - stackOpts := stack.Options{ - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - } - if opts.EnableV4 { - stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol) - } - if opts.EnableV6 { - stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol) - } - s := stack.New(stackOpts) - - const sendBufferSize = 1 << 20 // 1 MiB - const recvBufferSize = 1 << 20 // 1 MiB - // Allow minimum send/receive buffer sizes to be 1 during tests. - sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err) - } - - rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err) - } - - // Increase minimum RTO in tests to avoid test flakes due to early - // retransmit in case the test executors are overloaded and cause timers - // to fire earlier than expected. - minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second) - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { - t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) - } - - // Some of the congestion control tests send up to 640 packets, we so - // set the channel size to 1000. - ep := channel.New(1000, opts.MTU, "") - wep := stack.LinkEndpoint(ep) - if testing.Verbose() { - wep = sniffer.New(ep) - } - nicOpts := stack.NICOptions{Name: "nic1"} - if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) - } - wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, "")) - if testing.Verbose() { - wep2 = sniffer.New(channel.New(1000, opts.MTU, "")) - } - opts2 := stack.NICOptions{Name: "nic2"} - if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) - } - - var routeTable []tcpip.Route - - if opts.EnableV4 { - v4ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: StackAddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) - } - routeTable = append(routeTable, tcpip.Route{ - Destination: header.IPv4EmptySubnet, - NIC: 1, - }) - } - - if opts.EnableV6 { - v6ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: StackV6AddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) - } - routeTable = append(routeTable, tcpip.Route{ - Destination: header.IPv6EmptySubnet, - NIC: 1, - }) - } - - s.SetRouteTable(routeTable) - - return &Context{ - t: t, - s: s, - linkEP: ep, - WindowScale: uint8(tcp.FindWndScale(recvBufferSize)), - } -} - -// Cleanup closes the context endpoint if required. -func (c *Context) Cleanup() { - if c.EP != nil { - c.EP.Close() - } - c.Stack().Close() -} - -// Stack returns a reference to the stack in the Context. -func (c *Context) Stack() *stack.Stack { - return c.s -} - -// CheckNoPacketTimeout verifies that no packet is received during the time -// specified by wait. -func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { - c.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), wait) - defer cancel() - if _, ok := c.linkEP.ReadContext(ctx); ok { - c.t.Fatal(errMsg) - } -} - -// CheckNoPacket verifies that no packet is received for 1 second. -func (c *Context) CheckNoPacket(errMsg string) { - c.CheckNoPacketTimeout(errMsg, 1*time.Second) -} - -// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies -// that it is an IPv4 packet with the expected source and destination -// addresses. If no packet is received in the specified timeout it will return -// nil. -func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte { - c.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - return nil - } - - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - - // Just check that the stack set the transport protocol number for outbound - // TCP messages. - // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part - // of the headerinfo. - if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { - c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - if p.Pkt.GSOOptions.Type != stack.GSONone && p.Pkt.GSOOptions.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("got L3HdrLen = %d, want = %d", p.Pkt.GSOOptions.L3HdrLen, header.IPv4MinimumSize) - } - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b -} - -// GetPacket reads a packet from the link layer endpoint and verifies -// that it is an IPv4 packet with the expected source and destination -// addresses. -func (c *Context) GetPacket() []byte { - c.t.Helper() - - p := c.GetPacketWithTimeout(5 * time.Second) - if p == nil { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - return p -} - -// GetPacketNonBlocking reads a packet from the link layer endpoint -// and verifies that it is an IPv4 packet with the expected source -// and destination address. If no packet is available it will return -// nil immediately. -func (c *Context) GetPacketNonBlocking() []byte { - c.t.Helper() - - p, ok := c.linkEP.Read() - if !ok { - return nil - } - - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - - // Just check that the stack set the transport protocol number for outbound - // TCP messages. - // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part - // of the headerinfo. - if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber { - c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber) - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b -} - -// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. -func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, p1, p2 []byte, maxTotalSize int) { - // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) - if len(buf) > maxTotalSize { - buf = buf[:maxTotalSize] - } - - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(header.ICMPv4ProtocolNumber), - SrcAddr: TestAddr, - DstAddr: StackAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) - icmp.SetType(typ) - icmp.SetCode(code) - const icmpv4VariableHeaderOffset = 4 - copy(icmp[icmpv4VariableHeaderOffset:], p1) - copy(icmp[header.ICMPv4PayloadOffset:], p2) - icmp.SetChecksum(0) - checksum := ^header.Checksum(icmp, 0 /* initial */) - icmp.SetChecksum(checksum) - - // Inject packet. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - }) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) -} - -// BuildSegment builds a TCP segment based on the given Headers and payload. -func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { - return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr) -} - -// BuildSegmentWithAddrs builds a TCP segment based on the given Headers, -// payload and source and destination IPv4 addresses. -func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(tcp.ProtocolNumber), - SrcAddr: src, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the TCP header. - t := header.TCP(buf[header.IPv4MinimumSize:]) - t.Encode(&header.TCPFields{ - SrcPort: h.SrcPort, - DstPort: h.DstPort, - SeqNum: uint32(h.SeqNum), - AckNum: uint32(h.AckNum), - DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), - Flags: h.Flags, - WindowSize: uint16(h.RcvWnd), - }) - - // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) - - // Calculate the TCP checksum and set it. - xsum = header.Checksum(payload, xsum) - t.SetChecksum(^t.CalculateChecksum(xsum)) - - // Inject packet. - return buf.ToVectorisedView() -} - -// SendSegment sends a TCP segment that has already been built and written to a -// buffer.VectorisedView. -func (c *Context) SendSegment(s buffer.VectorisedView) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: s, - }) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) -} - -// SendPacket builds and sends a TCP segment(with the provided payload & TCP -// headers) in an IPv4 packet via the link layer endpoint. -func (c *Context) SendPacket(payload []byte, h *Headers) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: c.BuildSegment(payload, h), - }) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) -} - -// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload -// & TCPheaders) in an IPv4 packet via the link layer endpoint using the -// provided source and destination IPv4 addresses. -func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: c.BuildSegmentWithAddrs(payload, h, src, dst), - }) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) -} - -// SendAck sends an ACK packet. -func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) { - c.SendAckWithSACK(seq, bytesReceived, nil) -} - -// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified. -func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) { - options := make([]byte, 40) - offset := 0 - if len(sackBlocks) > 0 { - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) - } - - c.SendPacket(nil, &Headers{ - SrcPort: TestPort, - DstPort: c.Port, - Flags: header.TCPFlagAck, - SeqNum: seq, - AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), - RcvWnd: 30000, - TCPOpts: options[:offset], - }) -} - -// ReceiveAndCheckPacket reads a packet from the link layer endpoint and -// verifies that the packet packet payload of packet matches the slice -// of data indicated by offset & size. -func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { - c.t.Helper() - - c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0) -} - -// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint -// and verifies that the packet packet payload of packet matches the slice of -// data indicated by offset & size and skips optlen bytes in addition to the IP -// TCP headers when comparing the data. -func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) { - c.t.Helper() - - b := c.GetPacket() - checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize+optlen), - checker.TCP( - checker.DstPort(TestPort), - checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 { - c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) - } -} - -// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint -// and verifies that the packet packet payload of packet matches the slice of -// data indicated by offset & size. It returns true if a packet was received and -// processed. -func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool { - c.t.Helper() - - b := c.GetPacketNonBlocking() - if b == nil { - return false - } - checker.IPv4(c.t, b, - checker.PayloadLen(size+header.TCPMinimumSize), - checker.TCP( - checker.DstPort(TestPort), - checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), - ), - ) - - pdata := data[offset:][:size] - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 { - c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) - } - return true -} - -// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only -// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 -// only endpoint instead of a default dual stack socket. -func (c *Context) CreateV6Endpoint(v6only bool) { - var err tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - c.EP.SocketOptions().SetV6Only(v6only) -} - -// GetV6Packet reads a single packet from the link layer endpoint of the context -// and asserts that it is an IPv6 Packet with the expected src/dest addresses. -func (c *Context) GetV6Packet() []byte { - c.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) - } - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) - return b -} - -// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of -// the context. -func (c *Context) SendV6Packet(payload []byte, h *Headers) { - c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr) -} - -// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer -// endpoint of the context using the provided source and destination IPv6 -// addresses. -func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.TCPMinimumSize + len(payload)), - TransportProtocol: tcp.ProtocolNumber, - HopLimit: 65, - SrcAddr: src, - DstAddr: dst, - }) - - // Initialize the TCP header. - t := header.TCP(buf[header.IPv6MinimumSize:]) - t.Encode(&header.TCPFields{ - SrcPort: h.SrcPort, - DstPort: h.DstPort, - SeqNum: uint32(h.SeqNum), - AckNum: uint32(h.AckNum), - DataOffset: header.TCPMinimumSize, - Flags: h.Flags, - WindowSize: uint16(h.RcvWnd), - }) - - // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) - - // Calculate the TCP checksum and set it. - xsum = header.Checksum(payload, xsum) - t.SetChecksum(^t.CalculateChecksum(xsum)) - - // Inject packet. - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - }) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt) -} - -// CreateConnected creates a connected TCP endpoint. -func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) { - c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) -} - -// Connect performs the 3-way handshake for c.EP with the provided Initial -// Sequence Number (iss) and receive window(rcvWnd) and any options if -// specified. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -// -// PreCondition: c.EP must already be created. -func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) { - c.t.Helper() - - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.WritableEvents) - defer c.WQ.EventUnregister(&waitEntry) - - err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - c.t.Fatalf("Unexpected return value from Connect: %v", err) - } - - // Receive SYN packet. - b := c.GetPacket() - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - tcpHdr := header.TCP(header.IPv4(b).Payload()) - synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */) - c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) - - c.SendPacket(nil, &Headers{ - SrcPort: tcpHdr.DestinationPort(), - DstPort: tcpHdr.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: rcvWnd, - TCPOpts: options, - }) - - // Receive ACK packet. - checker.IPv4(c.t, c.GetPacket(), - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS)+1), - checker.TCPAckNum(uint32(iss)+1), - ), - ) - - // Wait for connection to be established. - select { - case <-notifyCh: - if err := c.EP.LastError(); err != nil { - c.t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - c.RcvdWindowScale = uint8(synOpts.WS) - c.Port = tcpHdr.SourcePort() -} - -// Create creates a TCP endpoint. -func (c *Context) Create(epRcvBuf int) { - // Create TCP endpoint. - var err tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if epRcvBuf != -1 { - c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf), true /* notify */) - } -} - -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { - c.Create(epRcvBuf) - c.Connect(iss, rcvWnd, options) -} - -// RawEndpoint is just a small wrapper around a TCP endpoint's state to make -// sending data and ACK packets easy while being able to manipulate the sequence -// numbers and timestamp values as needed. -type RawEndpoint struct { - C *Context - SrcPort uint16 - DstPort uint16 - Flags header.TCPFlags - NextSeqNum seqnum.Value - AckNum seqnum.Value - WndSize seqnum.Size - RecentTS uint32 // Stores the latest timestamp to echo back. - TSVal uint32 // TSVal stores the last timestamp sent by this endpoint. - - // SackPermitted is true if SACKPermitted option was negotiated for this endpoint. - SACKPermitted bool -} - -// SendPacketWithTS embeds the provided tsVal in the Timestamp option -// for the packet to be sent out. -func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) { - r.TSVal = tsVal - tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:]) - r.SendPacket(payload, tsOpt[:]) -} - -// SendPacket is a small wrapper function to build and send packets. -func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { - packetHeaders := &Headers{ - SrcPort: r.SrcPort, - DstPort: r.DstPort, - Flags: r.Flags, - SeqNum: r.NextSeqNum, - AckNum: r.AckNum, - RcvWnd: r.WndSize, - TCPOpts: opts, - } - r.C.SendPacket(payload, packetHeaders) - r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) -} - -// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches -// the provided tsVal as well as returns the original packet. -func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte { - r.C.t.Helper() - // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(r.AckNum)), - checker.TCPAckNum(uint32(r.NextSeqNum)), - checker.TCPTimestampChecker(true, 0, tsVal), - ), - ) - // Store the parsed TSVal from the ack as recentTS. - tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) - opts := tcpSeg.ParsedOptions() - r.RecentTS = opts.TSVal - return ackPacket -} - -// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided -// tsVal. -func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { - r.C.t.Helper() - _ = r.VerifyAndReturnACKWithTS(tsVal) -} - -// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK -// matches the provided rcvWnd. -func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) { - r.C.t.Helper() - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(r.AckNum)), - checker.TCPAckNum(uint32(r.NextSeqNum)), - checker.TCPWindow(rcvWnd), - ), - ) -} - -// VerifyACKNoSACK verifies that the ACK does not contain a SACK block. -func (r *RawEndpoint) VerifyACKNoSACK() { - r.VerifyACKHasSACK(nil) -} - -// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks. -func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { - // Read ACK and verify that the TCP options in the segment do - // not contain a SACK block. - ackPacket := r.C.GetPacket() - checker.IPv4(r.C.t, ackPacket, - checker.TCP( - checker.DstPort(r.SrcPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(r.AckNum)), - checker.TCPAckNum(uint32(r.NextSeqNum)), - checker.TCPSACKBlockChecker(sackBlocks), - ), - ) -} - -// CreateConnectedWithOptions creates and connects c.ep with the specified TCP -// options enabled and returns a RawEndpoint which represents the other end of -// the connection. -// -// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK -// does not carry an option that was not requested. -func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { - var err tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Start connection attempt. - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.WritableEvents) - defer c.WQ.EventUnregister(&waitEntry) - - testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort} - err = c.EP.Connect(testFullAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err) - } - // Receive SYN packet. - b := c.GetPacket() - // Validate that the syn has the timestamp option and a valid - // TS value. - mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) - - checker.IPv4(c.t, b, - checker.TCP( - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagSyn), - checker.TCPSynOptions(header.TCPSynOptions{ - MSS: mss, - TS: true, - WS: int(c.WindowScale), - SACKPermitted: c.SACKEnabled(), - }), - ), - ) - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - tcpSeg := header.TCP(header.IPv4(b).Payload()) - synOptions := header.ParseSynOptions(tcpSeg.Options(), false) - - // Build options w/ tsVal to be sent in the SYN-ACK. - synAckOptions := make([]byte, header.TCPOptionsMaximumSize) - offset := 0 - if wantOptions.WS != -1 { - offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:]) - } - if wantOptions.TS { - offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:]) - } - if wantOptions.SACKPermitted { - offset += header.EncodeSACKPermittedOption(synAckOptions[offset:]) - } - - offset += header.AddTCPOptionPadding(synAckOptions, offset) - - // Build SYN-ACK. - c.IRS = seqnum.Value(tcpSeg.SequenceNumber()) - iss := seqnum.Value(TestInitialSequenceNumber) - c.SendPacket(nil, &Headers{ - SrcPort: tcpSeg.DestinationPort(), - DstPort: tcpSeg.SourcePort(), - Flags: header.TCPFlagSyn | header.TCPFlagAck, - SeqNum: iss, - AckNum: c.IRS.Add(1), - RcvWnd: 30000, - TCPOpts: synAckOptions[:offset], - }) - - // Read ACK. - ackPacket := c.GetPacket() - - // Verify TCP header fields. - tcpCheckers := []checker.TransportChecker{ - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.TCPSeqNum(uint32(c.IRS) + 1), - checker.TCPAckNum(uint32(iss) + 1), - } - - // Verify that tsEcr of ACK packet is wantOptions.TSVal if the - // timestamp option was enabled, if not then we verify that - // there is no timestamp in the ACK packet. - if wantOptions.TS { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal)) - } else { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) - } - - checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...)) - - ackSeg := header.TCP(header.IPv4(ackPacket).Payload()) - ackOptions := ackSeg.ParsedOptions() - - // Wait for connection to be established. - select { - case <-notifyCh: - if err := c.EP.LastError(); err != nil { - c.t.Fatalf("Unexpected error when connecting: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for connection") - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) - } - - // Store the source port in use by the endpoint. - c.Port = tcpSeg.SourcePort() - - // Mark in context that timestamp option is enabled for this endpoint. - c.TimeStampEnabled = true - c.RcvdWindowScale = uint8(synOptions.WS) - return &RawEndpoint{ - C: c, - SrcPort: tcpSeg.DestinationPort(), - DstPort: tcpSeg.SourcePort(), - Flags: header.TCPFlagAck | header.TCPFlagPsh, - NextSeqNum: iss + 1, - AckNum: c.IRS.Add(1), - WndSize: 30000, - RecentTS: ackOptions.TSVal, - TSVal: wantOptions.TSVal, - SACKPermitted: wantOptions.SACKPermitted, - } -} - -// AcceptWithOptions initializes a listening endpoint and connects to it with the -// provided options enabled. It also verifies that the SYN-ACK has the expected -// values for the provided options. -// -// The function returns a RawEndpoint representing the other end of the accepted -// endpoint. -func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - // Create EP and start listening. - wq := &waiter.Queue{} - ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - - if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if err := ep.Listen(10); err != nil { - c.t.Fatalf("Listen failed: %v", err) - } - if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) - - // Try to accept the connection. - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - - c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for connection to be established. - select { - case <-ch: - c.EP, _, err = ep.Accept(nil) - if err != nil { - c.t.Fatalf("Accept failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for accept") - } - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - return rep -} - -// PassiveConnect just disables WindowScaling and delegates the call to -// PassiveConnectWithOptions. -func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) { - synOptions.WS = -1 - c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions) -} - -// PassiveConnectWithOptions initiates a new connection (with the specified TCP -// options enabled) to the port on which the Context.ep is listening for new -// connections. It also validates that the SYN-ACK has the expected values for -// the enabled options. -// -// NOTE: MSS is not a negotiated option and it can be asymmetric -// in each direction. This function uses the maxPayload to set the MSS to be -// sent to the peer on a connect and validates that the MSS in the SYN-ACK -// response is equal to the MTU - (tcphdr len + iphdr len). -// -// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the -// value of the window scaling option to be sent in the SYN. If synOptions.WS > -// 0 then we send the WindowScale option. -func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { - c.t.Helper() - opts := make([]byte, header.TCPOptionsMaximumSize) - offset := 0 - offset += header.EncodeMSSOption(uint32(maxPayload), opts) - - if synOptions.WS >= 0 { - offset += header.EncodeWSOption(3, opts[offset:]) - } - if synOptions.TS { - offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:]) - } - - if synOptions.SACKPermitted { - offset += header.EncodeSACKPermittedOption(opts[offset:]) - } - - paddingToAdd := 4 - offset%4 - // Now add any padding bytes that might be required to quad align the - // options. - for i := offset; i < offset+paddingToAdd; i++ { - opts[i] = header.TCPOptionNOP - } - offset += paddingToAdd - - // Send a SYN request. - iss := seqnum.Value(TestInitialSequenceNumber) - c.SendPacket(nil, &Headers{ - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagSyn, - SeqNum: iss, - RcvWnd: 30000, - TCPOpts: opts[:offset], - }) - - // Receive the SYN-ACK reply. Make sure MSS and other expected options - // are present. - b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */) - c.IRS = seqnum.Value(tcp.SequenceNumber()) - - tcpCheckers := []checker.TransportChecker{ - checker.SrcPort(StackPort), - checker.DstPort(TestPort), - checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), - checker.TCPAckNum(uint32(iss) + 1), - checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), - } - - // If TS option was enabled in the original SYN then add a checker to - // validate the Timestamp option in the SYN-ACK. - if synOptions.TS { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal)) - } else { - tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) - } - - checker.IPv4(c.t, b, checker.TCP(tcpCheckers...)) - rcvWnd := seqnum.Size(30000) - ackHeaders := &Headers{ - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagAck, - SeqNum: iss + 1, - AckNum: c.IRS + 1, - RcvWnd: rcvWnd, - } - - // If WS was expected to be in effect then scale the advertised window - // correspondingly. - if synOptions.WS > 0 { - ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS) - } - - parsedOpts := tcp.ParsedOptions() - if synOptions.TS { - // Echo the tsVal back to the peer in the tsEcr field of the - // timestamp option. - // Increment TSVal by 1 from the value sent in the SYN and echo - // the TSVal in the SYN-ACK in the TSEcr field. - opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} - header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:]) - ackHeaders.TCPOpts = opts[:] - } - - // Send ACK. - c.SendPacket(nil, ackHeaders) - - c.RcvdWindowScale = uint8(rcvdSynOptions.WS) - c.Port = StackPort - - return &RawEndpoint{ - C: c, - SrcPort: TestPort, - DstPort: StackPort, - Flags: header.TCPFlagPsh | header.TCPFlagAck, - NextSeqNum: iss + 1, - AckNum: c.IRS + 1, - WndSize: rcvWnd, - SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(), - RecentTS: parsedOpts.TSVal, - TSVal: synOptions.TSVal + 1, - } -} - -// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true -// for the Stack in the context. -func (c *Context) SACKEnabled() bool { - var v tcpip.TCPSACKEnabled - if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { - // Stack doesn't support SACK. So just return. - return false - } - return bool(v) -} - -// SetGSOEnabled enables or disables generic segmentation offload. -func (c *Context) SetGSOEnabled(enable bool) { - if enable { - c.linkEP.SupportedGSOKind = stack.HWGSOSupported - } else { - c.linkEP.SupportedGSOKind = stack.GSONotSupported - } -} - -// MSSWithoutOptions returns the value for the MSS used by the stack when no -// options are in use. -func (c *Context) MSSWithoutOptions() uint16 { - return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize) -} - -// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no -// options are in use for IPv6 packets. -func (c *Context) MSSWithoutOptionsV6() uint16 { - return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize) -} diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go deleted file mode 100644 index dbd6dff54..000000000 --- a/pkg/tcpip/transport/tcp/timer_test.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sleep" -) - -func TestCleanup(t *testing.T) { - const ( - timerDurationSeconds = 2 - isAssertedTimeoutSeconds = timerDurationSeconds + 1 - ) - - tmr := timer{} - w := sleep.Waker{} - tmr.init(&w) - tmr.enable(timerDurationSeconds * time.Second) - tmr.cleanup() - - if want := (timer{}); tmr != want { - t.Errorf("got tmr = %+v, want = %+v", tmr, want) - } - - // The waker should not be asserted. - for i := 0; i < isAssertedTimeoutSeconds; i++ { - time.Sleep(time.Second) - if w.IsAsserted() { - t.Fatalf("waker asserted unexpectedly") - } - } -} diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD deleted file mode 100644 index 3ad6994a7..000000000 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "tcpconntrack", - srcs = ["tcp_conntrack.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( - name = "tcpconntrack_test", - size = "small", - srcs = ["tcp_conntrack_test.go"], - deps = [ - ":tcpconntrack", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go deleted file mode 100644 index 6c5ddc3c7..000000000 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go +++ /dev/null @@ -1,511 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpconntrack_test - -import ( - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" -) - -// connected creates a connection tracker TCB and sets it to a connected state -// by performing a 3-way handshake. -func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: iss, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: irw, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: irs, - AckNum: iss + 1, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: isw, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: iss + 1, - AckNum: irs + 1, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: irw, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - return &tcb -} - -func TestConnectionRefused(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive RST. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst | header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestConnectionRefusedInSynRcvd(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive RST with no ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestConnectionResetInSynRcvd(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send RST with no ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagRst, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) - } -} - -func TestRetransmitOnSynSent(t *testing.T) { - // Send initial SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Retransmit the same SYN. - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting) - } -} - -func TestRetransmitOnSynRcvd(t *testing.T) { - // Send initial SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive SYN. This will cause the state to go to SYN-RCVD. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Retransmit the original SYN. - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Transmit a SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } -} - -func TestClosedBySelf(t *testing.T) { - tcb := connected(t, 1234, 789, 30000, 50000) - - // Send FIN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 1236, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1236, - AckNum: 791, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) - } -} - -func TestClosedByPeer(t *testing.T) { - tcb := connected(t, 1234, 789, 30000, 50000) - - // Receive FIN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 790, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 791, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 791, - AckNum: 1236, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer) - } -} - -func TestSendAndReceiveDataClosedBySelf(t *testing.T) { - sseq := uint32(1234) - rseq := uint32(789) - tcb := connected(t, sseq, rseq, 30000, 50000) - sseq++ - rseq++ - - // Send some data. - tcp := make(header.TCP, header.TCPMinimumSize+1024) - - for i := uint32(0); i < 10; i++ { - // Send some data. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - sseq += uint32(len(tcp)) - header.TCPMinimumSize - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive ack for data. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - - for i := uint32(0); i < 10; i++ { - // Receive some data. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 50000, - }) - rseq += uint32(len(tcp)) - header.TCPMinimumSize - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ack for data. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - - // Send FIN. - tcp = tcp[:header.TCPMinimumSize] - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 30000, - }) - sseq++ - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Receive FIN/ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: rseq, - AckNum: sseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck | header.TCPFlagFin, - WindowSize: 50000, - }) - rseq++ - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: sseq, - AckNum: rseq, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) - } -} - -func TestIgnoreBadResetOnSynSent(t *testing.T) { - // Send SYN. - tcp := make(header.TCP, header.TCPMinimumSize) - tcp.Encode(&header.TCPFields{ - SeqNum: 1234, - AckNum: 0, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn, - WindowSize: 30000, - }) - - tcb := tcpconntrack.TCB{} - tcb.Init(tcp) - - // Receive a RST with a bad ACK, it should not cause the connection to - // be reset. - acks := []uint32{1234, 1236, 1000, 5000} - flags := []header.TCPFlags{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} - for _, a := range acks { - for _, f := range flags { - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: a, - DataOffset: header.TCPMinimumSize, - Flags: f, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - } - } - - // Complete the handshake. - // Receive SYN-ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 789, - AckNum: 1235, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagSyn | header.TCPFlagAck, - WindowSize: 50000, - }) - - if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } - - // Send ACK. - tcp.Encode(&header.TCPFields{ - SeqNum: 1235, - AckNum: 790, - DataOffset: header.TCPMinimumSize, - Flags: header.TCPFlagAck, - WindowSize: 30000, - }) - - if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { - t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) - } -} diff --git a/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go new file mode 100644 index 000000000..ff53204da --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package tcpconntrack diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD deleted file mode 100644 index dd5c910ae..000000000 --- a/pkg/tcpip/transport/udp/BUILD +++ /dev/null @@ -1,64 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "udp_packet_list", - out = "udp_packet_list.go", - package = "udp", - prefix = "udpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*udpPacket", - "Linker": "*udpPacket", - }, -) - -go_library( - name = "udp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "udp_packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", - "//pkg/tcpip/ports", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/waiter", - ], -) - -go_test( - name = "udp_x_test", - size = "small", - srcs = ["udp_test.go"], - deps = [ - ":udp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go new file mode 100644 index 000000000..c396f77c9 --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_packet_list.go @@ -0,0 +1,221 @@ +package udp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type udpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type udpPacketList struct { + head *udpPacket + tail *udpPacket +} + +// Reset resets list l to the empty state. +func (l *udpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +// +//go:nosplit +func (l *udpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +// +//go:nosplit +func (l *udpPacketList) Front() *udpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +// +//go:nosplit +func (l *udpPacketList) Back() *udpPacket { + return l.tail +} + +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +// +//go:nosplit +func (l *udpPacketList) Len() (count int) { + for e := l.Front(); e != nil; e = (udpPacketElementMapper{}.linkerFor(e)).Next() { + count++ + } + return count +} + +// PushFront inserts the element e at the front of list l. +// +//go:nosplit +func (l *udpPacketList) PushFront(e *udpPacket) { + linker := udpPacketElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) + if l.head != nil { + udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +// +//go:nosplit +func (l *udpPacketList) PushBack(e *udpPacket) { + linker := udpPacketElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) + if l.tail != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +// +//go:nosplit +func (l *udpPacketList) PushBackList(m *udpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +// +//go:nosplit +func (l *udpPacketList) InsertAfter(b, e *udpPacket) { + bLinker := udpPacketElementMapper{}.linkerFor(b) + eLinker := udpPacketElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) + + if a != nil { + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +// +//go:nosplit +func (l *udpPacketList) InsertBefore(a, e *udpPacket) { + aLinker := udpPacketElementMapper{}.linkerFor(a) + eLinker := udpPacketElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) + + if b != nil { + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +// +//go:nosplit +func (l *udpPacketList) Remove(e *udpPacket) { + linker := udpPacketElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() + + if prev != nil { + udpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else if l.head == e { + l.head = next + } + + if next != nil { + udpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else if l.tail == e { + l.tail = prev + } + + linker.SetNext(nil) + linker.SetPrev(nil) +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type udpPacketEntry struct { + next *udpPacket + prev *udpPacket +} + +// Next returns the entry that follows e in the list. +// +//go:nosplit +func (e *udpPacketEntry) Next() *udpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +// +//go:nosplit +func (e *udpPacketEntry) Prev() *udpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +// +//go:nosplit +func (e *udpPacketEntry) SetNext(elem *udpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +// +//go:nosplit +func (e *udpPacketEntry) SetPrev(elem *udpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go new file mode 100644 index 000000000..092daa0b8 --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -0,0 +1,238 @@ +// automatically generated by stateify. + +package udp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (u *udpPacket) StateTypeName() string { + return "pkg/tcpip/transport/udp.udpPacket" +} + +func (u *udpPacket) StateFields() []string { + return []string{ + "udpPacketEntry", + "senderAddress", + "destinationAddress", + "packetInfo", + "data", + "timestamp", + "tos", + } +} + +func (u *udpPacket) beforeSave() {} + +// +checklocksignore +func (u *udpPacket) StateSave(stateSinkObject state.Sink) { + u.beforeSave() + var dataValue buffer.VectorisedView = u.saveData() + stateSinkObject.SaveValue(4, dataValue) + stateSinkObject.Save(0, &u.udpPacketEntry) + stateSinkObject.Save(1, &u.senderAddress) + stateSinkObject.Save(2, &u.destinationAddress) + stateSinkObject.Save(3, &u.packetInfo) + stateSinkObject.Save(5, &u.timestamp) + stateSinkObject.Save(6, &u.tos) +} + +func (u *udpPacket) afterLoad() {} + +// +checklocksignore +func (u *udpPacket) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &u.udpPacketEntry) + stateSourceObject.Load(1, &u.senderAddress) + stateSourceObject.Load(2, &u.destinationAddress) + stateSourceObject.Load(3, &u.packetInfo) + stateSourceObject.Load(5, &u.timestamp) + stateSourceObject.Load(6, &u.tos) + stateSourceObject.LoadValue(4, new(buffer.VectorisedView), func(y interface{}) { u.loadData(y.(buffer.VectorisedView)) }) +} + +func (e *endpoint) StateTypeName() string { + return "pkg/tcpip/transport/udp.endpoint" +} + +func (e *endpoint) StateFields() []string { + return []string{ + "TransportEndpointInfo", + "DefaultSocketOptionsHandler", + "waiterQueue", + "uniqueID", + "rcvReady", + "rcvList", + "rcvBufSize", + "rcvClosed", + "state", + "dstPort", + "ttl", + "multicastTTL", + "multicastAddr", + "multicastNICID", + "portFlags", + "lastError", + "boundBindToDevice", + "boundPortFlags", + "sendTOS", + "shutdownFlags", + "multicastMemberships", + "effectiveNetProtos", + "owner", + "ops", + "frozen", + } +} + +// +checklocksignore +func (e *endpoint) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.TransportEndpointInfo) + stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler) + stateSinkObject.Save(2, &e.waiterQueue) + stateSinkObject.Save(3, &e.uniqueID) + stateSinkObject.Save(4, &e.rcvReady) + stateSinkObject.Save(5, &e.rcvList) + stateSinkObject.Save(6, &e.rcvBufSize) + stateSinkObject.Save(7, &e.rcvClosed) + stateSinkObject.Save(8, &e.state) + stateSinkObject.Save(9, &e.dstPort) + stateSinkObject.Save(10, &e.ttl) + stateSinkObject.Save(11, &e.multicastTTL) + stateSinkObject.Save(12, &e.multicastAddr) + stateSinkObject.Save(13, &e.multicastNICID) + stateSinkObject.Save(14, &e.portFlags) + stateSinkObject.Save(15, &e.lastError) + stateSinkObject.Save(16, &e.boundBindToDevice) + stateSinkObject.Save(17, &e.boundPortFlags) + stateSinkObject.Save(18, &e.sendTOS) + stateSinkObject.Save(19, &e.shutdownFlags) + stateSinkObject.Save(20, &e.multicastMemberships) + stateSinkObject.Save(21, &e.effectiveNetProtos) + stateSinkObject.Save(22, &e.owner) + stateSinkObject.Save(23, &e.ops) + stateSinkObject.Save(24, &e.frozen) +} + +// +checklocksignore +func (e *endpoint) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.TransportEndpointInfo) + stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler) + stateSourceObject.Load(2, &e.waiterQueue) + stateSourceObject.Load(3, &e.uniqueID) + stateSourceObject.Load(4, &e.rcvReady) + stateSourceObject.Load(5, &e.rcvList) + stateSourceObject.Load(6, &e.rcvBufSize) + stateSourceObject.Load(7, &e.rcvClosed) + stateSourceObject.Load(8, &e.state) + stateSourceObject.Load(9, &e.dstPort) + stateSourceObject.Load(10, &e.ttl) + stateSourceObject.Load(11, &e.multicastTTL) + stateSourceObject.Load(12, &e.multicastAddr) + stateSourceObject.Load(13, &e.multicastNICID) + stateSourceObject.Load(14, &e.portFlags) + stateSourceObject.Load(15, &e.lastError) + stateSourceObject.Load(16, &e.boundBindToDevice) + stateSourceObject.Load(17, &e.boundPortFlags) + stateSourceObject.Load(18, &e.sendTOS) + stateSourceObject.Load(19, &e.shutdownFlags) + stateSourceObject.Load(20, &e.multicastMemberships) + stateSourceObject.Load(21, &e.effectiveNetProtos) + stateSourceObject.Load(22, &e.owner) + stateSourceObject.Load(23, &e.ops) + stateSourceObject.Load(24, &e.frozen) + stateSourceObject.AfterLoad(e.afterLoad) +} + +func (m *multicastMembership) StateTypeName() string { + return "pkg/tcpip/transport/udp.multicastMembership" +} + +func (m *multicastMembership) StateFields() []string { + return []string{ + "nicID", + "multicastAddr", + } +} + +func (m *multicastMembership) beforeSave() {} + +// +checklocksignore +func (m *multicastMembership) StateSave(stateSinkObject state.Sink) { + m.beforeSave() + stateSinkObject.Save(0, &m.nicID) + stateSinkObject.Save(1, &m.multicastAddr) +} + +func (m *multicastMembership) afterLoad() {} + +// +checklocksignore +func (m *multicastMembership) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &m.nicID) + stateSourceObject.Load(1, &m.multicastAddr) +} + +func (l *udpPacketList) StateTypeName() string { + return "pkg/tcpip/transport/udp.udpPacketList" +} + +func (l *udpPacketList) StateFields() []string { + return []string{ + "head", + "tail", + } +} + +func (l *udpPacketList) beforeSave() {} + +// +checklocksignore +func (l *udpPacketList) StateSave(stateSinkObject state.Sink) { + l.beforeSave() + stateSinkObject.Save(0, &l.head) + stateSinkObject.Save(1, &l.tail) +} + +func (l *udpPacketList) afterLoad() {} + +// +checklocksignore +func (l *udpPacketList) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &l.head) + stateSourceObject.Load(1, &l.tail) +} + +func (e *udpPacketEntry) StateTypeName() string { + return "pkg/tcpip/transport/udp.udpPacketEntry" +} + +func (e *udpPacketEntry) StateFields() []string { + return []string{ + "next", + "prev", + } +} + +func (e *udpPacketEntry) beforeSave() {} + +// +checklocksignore +func (e *udpPacketEntry) StateSave(stateSinkObject state.Sink) { + e.beforeSave() + stateSinkObject.Save(0, &e.next) + stateSinkObject.Save(1, &e.prev) +} + +func (e *udpPacketEntry) afterLoad() {} + +// +checklocksignore +func (e *udpPacketEntry) StateLoad(stateSourceObject state.Source) { + stateSourceObject.Load(0, &e.next) + stateSourceObject.Load(1, &e.prev) +} + +func init() { + state.Register((*udpPacket)(nil)) + state.Register((*endpoint)(nil)) + state.Register((*multicastMembership)(nil)) + state.Register((*udpPacketList)(nil)) + state.Register((*udpPacketEntry)(nil)) +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go deleted file mode 100644 index 2e283e52b..000000000 --- a/pkg/tcpip/transport/udp/udp_test.go +++ /dev/null @@ -1,2568 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package udp_test - -import ( - "bytes" - "context" - "fmt" - "io/ioutil" - "math/rand" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// Addresses and ports used for testing. It is recommended that tests stick to -// using these addresses as it allows using the testFlow helper. -// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*' -// represents the remote endpoint. -const ( - v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" - stackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = v4MappedAddrPrefix + stackAddr - testV4MappedAddr = v4MappedAddrPrefix + testAddr - multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr - broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr - v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00" - - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testAddr = "\x0a\x00\x00\x02" - testPort = 4096 - invalidPort = 8192 - multicastAddr = "\xe8\x2b\xd3\xea" - multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - broadcastAddr = header.IPv4Broadcast - testTOS = 0x80 - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 -) - -// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in -// a packet header. These values are used to populate a header or verify one. -// Note that because they are used in packet headers, the addresses are never in -// a V4-mapped format. -type header4Tuple struct { - srcAddr tcpip.FullAddress - dstAddr tcpip.FullAddress -} - -// testFlow implements a helper type used for sending and receiving test -// packets. A given test flow value defines 1) the socket endpoint used for the -// test and 2) the type of packet send or received on the endpoint. E.g., a -// multicastV6Only flow is a V6 multicast packet passing through a V6-only -// endpoint. The type provides helper methods to characterize the flow (e.g., -// isV4) as well as return a proper header4Tuple for it. -type testFlow int - -const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket - reverseMulticast4 // V4 multicast src. Must fail. - reverseMulticast6 // V6 multicast src. Must fail. -) - -func (flow testFlow) String() string { - switch flow { - case unicastV4: - return "unicastV4" - case unicastV6: - return "unicastV6" - case unicastV6Only: - return "unicastV6Only" - case unicastV4in6: - return "unicastV4in6" - case multicastV4: - return "multicastV4" - case multicastV6: - return "multicastV6" - case multicastV6Only: - return "multicastV6Only" - case multicastV4in6: - return "multicastV4in6" - case broadcast: - return "broadcast" - case broadcastIn6: - return "broadcastIn6" - case reverseMulticast4: - return "reverseMulticast4" - case reverseMulticast6: - return "reverseMulticast6" - default: - return "unknown" - } -} - -// packetDirection explains if a flow is incoming (read) or outgoing (write). -type packetDirection int - -const ( - incoming packetDirection = iota - outgoing -) - -// header4Tuple returns the header4Tuple for the given flow and direction. Note -// that the tuple contains no mapped addresses as those only exist at the socket -// level but not at the packet header level. -func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { - var h header4Tuple - if flow.isV4() { - if d == outgoing { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, - dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, - } - } else { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, - dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, - } - } - if flow.isMulticast() { - h.dstAddr.Addr = multicastAddr - } else if flow.isBroadcast() { - h.dstAddr.Addr = broadcastAddr - } - } else { // IPv6 - if d == outgoing { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - } - } else { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - } - } - if flow.isMulticast() { - h.dstAddr.Addr = multicastV6Addr - } - } - if flow.isReverseMulticast() { - h.srcAddr.Addr = flow.getMcastAddr() - } - return h -} - -func (flow testFlow) getMcastAddr() tcpip.Address { - if flow.isV4() { - return multicastAddr - } - return multicastV6Addr -} - -// mapAddrIfApplicable converts the given V4 address into its V4-mapped version -// if it is applicable to the flow. -func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { - if flow.isMapped() { - return v4MappedAddrPrefix + v4Addr - } - return v4Addr -} - -// netProto returns the protocol number used for the network packet. -func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { - if flow.isV4() { - return ipv4.ProtocolNumber - } - return ipv6.ProtocolNumber -} - -// sockProto returns the protocol number used when creating the socket -// endpoint for this flow. -func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { - switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: - return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast, reverseMulticast4: - return ipv4.ProtocolNumber - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { - if flow.isV4() { - return checker.IPv4 - } - return checker.IPv6 -} - -func (flow testFlow) isV6() bool { return !flow.isV4() } -func (flow testFlow) isV4() bool { - return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() -} - -func (flow testFlow) isV6Only() bool { - switch flow { - case unicastV6Only, multicastV6Only: - return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isMulticast() bool { - switch flow { - case multicastV4, multicastV4in6, multicastV6, multicastV6Only: - return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isBroadcast() bool { - switch flow { - case broadcast, broadcastIn6: - return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isMapped() bool { - switch flow { - case unicastV4in6, multicastV4in6, broadcastIn6: - return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isReverseMulticast() bool { - switch flow { - case reverseMulticast4, reverseMulticast6: - return true - default: - return false - } -} - -type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -func newDualTestContext(t *testing.T, mtu uint32) *testContext { - t.Helper() - return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, - HandleLocal: true, - }) -} - -func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext { - t.Helper() - - s := stack.New(options) - ep := channel.New(256, mtu, "") - wep := stack.LinkEndpoint(ep) - - if testing.Verbose() { - wep = sniffer.New(ep) - } - if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %s", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %s", err) - } - - if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %s", err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return &testContext{ - t: t, - s: s, - linkEP: ep, - } -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { - c.t.Helper() - - var err tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) - if err != nil { - c.t.Fatal("NewEndpoint failed: ", err) - } -} - -func (c *testContext) createEndpointForFlow(flow testFlow) { - c.t.Helper() - - c.createEndpoint(flow.sockProto()) - if flow.isV6Only() { - c.ep.SocketOptions().SetV6Only(true) - } else if flow.isBroadcast() { - c.ep.SocketOptions().SetBroadcast(true) - } -} - -// getPacketAndVerify reads a packet from the link endpoint and verifies the -// header against expected values from the given test flow. In addition, it -// calls any extra checker functions provided. -func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { - c.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - if p.Proto != flow.netProto() { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) - } - - if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want { - c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want) - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - b := vv.ToView() - - h := flow.header4Tuple(outgoing) - checkers = append( - checkers, - checker.SrcAddr(h.srcAddr.Addr), - checker.DstAddr(h.dstAddr.Addr), - checker.UDP(checker.DstPort(h.dstAddr.Port)), - ) - flow.checkerFn()(c.t, b, checkers...) - return b -} - -// injectPacket creates a packet of the given flow and with the given payload, -// and injects it into the link endpoint. If badChecksum is true, the packet has -// a bad checksum in the UDP header. -func (c *testContext) injectPacket(flow testFlow, payload []byte, badChecksum bool) { - c.t.Helper() - - h := flow.header4Tuple(incoming) - if flow.isV4() { - buf := c.buildV4Packet(payload, &h) - if badChecksum { - // Invalidate the UDP header checksum field, taking care to avoid - // overflow to zero, which would disable checksum validation. - for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { - u.SetChecksum(u.Checksum() + 1) - if u.Checksum() != 0 { - break - } - } - } - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } else { - buf := c.buildV6Packet(payload, &h) - if badChecksum { - // Invalidate the UDP header checksum field (Unlike IPv4, zero is - // a valid checksum value for IPv6 so no need to avoid it). - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.SetChecksum(u.Checksum() + 1) - } - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } -} - -// buildV6Packet creates a V6 test packet with the given payload and header -// values in a buffer. -func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - TransportProtocol: udp.ProtocolNumber, - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - return buf -} - -// buildV4Packet creates a V4 test packet with the given payload and header -// values in a buffer. -func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - TOS: testTOS, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - return buf -} - -func newPayload() []byte { - return newMinPayload(30) -} - -func newMinPayload(minSize int) []byte { - b := make([]byte, minSize+rand.Intn(100)) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -func TestBindToDeviceOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}}) - - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer ep.Close() - - opts := stack.NICOptions{Name: "my_device"} - if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { - t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err) - } - - // nicIDPtr is used instead of taking the address of NICID literals, which is - // a compiler error. - nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { - return &s - } - - testActions := []struct { - name string - setBindToDevice *tcpip.NICID - setBindToDeviceError tcpip.Error - getBindToDevice int32 - }{ - {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, - {"BindToExistent", nicIDPtr(321), nil, 321}, - {"UnbindToDevice", nicIDPtr(0), nil, 0}, - } - for _, testAction := range testActions { - t.Run(testAction.name, func(t *testing.T) { - if testAction.setBindToDevice != nil { - bindToDevice := int32(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) - } - } - bindToDevice := ep.SocketOptions().GetBindToDevice() - if bindToDevice != testAction.getBindToDevice { - t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice) - } - }) - } -} - -// testReadInternal sends a packet of the given test flow into the stack by -// injecting it into the link endpoint. It then attempts to read it from the -// UDP endpoint and depending on if this was expected to succeed verifies its -// correctness including any additional checker functions provided. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { - c.t.Helper() - - payload := newPayload() - c.injectPacket(flow, payload, false) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.ReadableEvents) - defer c.wq.EventUnregister(&we) - - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - - var buf bytes.Buffer - res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Wait for data to become available. - select { - case <-ch: - res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - - case <-time.After(300 * time.Millisecond): - if packetShouldBeDropped { - return // expected to time out - } - c.t.Fatal("timed out waiting for data") - } - } - - if expectReadError && err != nil { - c.checkEndpointReadStats(1, epstats, err) - return - } - - if err != nil { - c.t.Fatal("Read failed:", err) - } - - if packetShouldBeDropped { - c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr) - } - - // Check the read result. - h := flow.header4Tuple(incoming) - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", // ControlMessages will be checked later. - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff) - } - - // Check the payload. - v := buf.Bytes() - if !bytes.Equal(payload, v) { - c.t.Fatalf("got payload = %x, want = %x", v, payload) - } - - // Run any checkers against the ControlMessages. - for _, f := range checkers { - f(c.t, res.ControlMessages) - } - - c.checkEndpointReadStats(1, epstats, err) -} - -// testRead sends a packet of the given test flow into the stack by injecting it -// into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness including any additional checker functions provided. -func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { - c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) -} - -// testFailingRead sends a packet of the given test flow into the stack by -// injecting it into the link endpoint. It then tries to read it from the UDP -// endpoint and expects this to fail. -func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { - c.t.Helper() - testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) -} - -func TestBindEphemeralPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("ep.Bind(...) failed: %s", err) - } -} - -func TestBindReservedPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - addr, err := c.ep.GetLocalAddress() - if err != nil { - t.Fatalf("GetLocalAddress failed: %s", err) - } - - // We can't bind the address reserved by the connected endpoint above. - { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - { - err := ep.Bind(addr) - if _, ok := err.(*tcpip.ErrPortInUse); !ok { - t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) - } - } - } - - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - // We can't bind ipv4-any on the port reserved by the connected endpoint - // above, since the endpoint is dual-stack. - { - err := ep.Bind(tcpip.FullAddress{Port: addr.Port}) - if _, ok := err.(*tcpip.ErrPortInUse); !ok { - t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) - } - } - // We can bind an ipv4 address on this port, though. - if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %s", err) - } - }() - - // Once the connected endpoint releases its port reservation, we are able to - // bind ipv4-any once again. - c.ep.Close() - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %s", err) - } - defer ep.Close() - if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %s", err) - } - }() -} - -func TestV4ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to v4 mapped wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV4ReadOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to local address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV6ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV6) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV6) -} - -// TestV4ReadSelfSource checks that packets coming from a local IP address are -// correctly dropped when handleLocal is true and not otherwise. -func TestV4ReadSelfSource(t *testing.T) { - for _, tt := range []struct { - name string - handleLocal bool - wantErr tcpip.Error - wantInvalidSource uint64 - }{ - {"HandleLocal", false, nil, 0}, - {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, - } { - t.Run(tt.name, func(t *testing.T) { - c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - HandleLocal: tt.handleLocal, - }) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4) - - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - h.srcAddr = h.dstAddr - - buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { - t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) - } - - if _, err := c.ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tt.wantErr { - t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr) - } - }) - } -} - -func TestV4ReadOnV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Test acceptance. - testRead(c, unicastV4) -} - -// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast -// address and receive data sent to that address. -func TestReadOnBoundToMulticast(t *testing.T) { - // FIXME(b/128189410): multicastV4in6 currently doesn't work as - // AddMembershipOption doesn't handle V4in6 addresses. - for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to multicast address. - mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr()) - if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - // Join multicast group. - ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} - if err := c.ep.SetSockOpt(&ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) - } - - // Check that we receive multicast packets but not unicast or broadcast - // ones. - testRead(c, flow) - testFailingRead(c, broadcast, false /* expectReadError */) - testFailingRead(c, unicastV4, false /* expectReadError */) - }) - } -} - -// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast -// address and can receive only broadcast data. -func TestV4ReadOnBoundToBroadcast(t *testing.T) { - for _, flow := range []testFlow{broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to broadcast address. - bcastAddr := flow.mapAddrIfApplicable(broadcastAddr) - if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Check that we receive broadcast packets but not unicast ones. - testRead(c, flow) - testFailingRead(c, unicastV4, false /* expectReadError */) - }) - } -} - -// TestReadFromMulticast checks that an endpoint will NOT receive a packet -// that was sent with multicast SOURCE address. -func TestReadFromMulticast(t *testing.T) { - for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - testFailingRead(c, flow, false /* expectReadError */) - }) - } -} - -// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY -// and receive broadcast and unicast data. -func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { - for _, flow := range []testFlow{broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s (", err) - } - - // Check that we receive both broadcast and unicast packets. - testRead(c, flow) - testRead(c, unicastV4) - }) - } -} - -// testFailingWrite sends a packet of the given test flow into the UDP endpoint -// and verifies it fails with the provided error code. -func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) { - c.t.Helper() - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - h := flow.header4Tuple(outgoing) - writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - - var r bytes.Reader - r.Reset(newPayload()) - _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, - }) - c.checkEndpointWriteStats(1, epstats, gotErr) - if gotErr != wantErr { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) - } -} - -// testWrite sends a packet of the given test flow from the UDP endpoint to the -// flow's destination address:port. It then receives it from the link endpoint -// and verifies its correctness including any additional checker functions -// provided. -func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - return testWriteAndVerifyInternal(c, flow, true, checkers...) -} - -// testWriteWithoutDestination sends a packet of the given test flow from the -// UDP endpoint without giving a destination address:port. It then receives it -// from the link endpoint and verifies its correctness including any additional -// checker functions provided. -func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - return testWriteAndVerifyInternal(c, flow, false, checkers...) -} - -func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View { - c.t.Helper() - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - - writeOpts := tcpip.WriteOptions{} - if setDest { - h := flow.header4Tuple(outgoing) - writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - writeOpts = tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, - } - } - var r bytes.Reader - payload := newPayload() - r.Reset(payload) - n, err := c.ep.Write(&r, writeOpts) - if err != nil { - c.t.Fatalf("Write failed: %s", err) - } - if n != int64(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - c.checkEndpointWriteStats(1, epstats, err) - return payload -} - -func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - payload := testWriteNoVerify(c, flow, setDest) - // Received the packet and check the payload. - b := c.getPacketAndVerify(flow, checkers...) - var udp header.UDP - if flow.isV4() { - udp = header.UDP(header.IPv4(b).Payload()) - } else { - udp = header.UDP(header.IPv6(b).Payload()) - } - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } - - return udp.SourcePort() -} - -func testDualWrite(c *testContext) uint16 { - c.t.Helper() - - v4Port := testWrite(c, unicastV4in6) - v6Port := testWrite(c, unicastV6) - if v4Port != v6Port { - c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) - } - - return v4Port -} - -func TestDualWriteUnbound(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - testDualWrite(c) -} - -func TestDualWriteBoundToWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - p := testDualWrite(c) - if p != stackPort { - c.t.Fatalf("Bad port: got %v, want %v", p, stackPort) - } -} - -func TestDualWriteConnectedToV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, unicastV6) - - // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{}) - const want = 1 - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { - c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) - } -} - -func TestDualWriteConnectedToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, unicastV4in6) - - // Write to v6 address. - testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) -} - -func TestV4WriteOnV6Only(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV6Only) - - // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{}) -} - -func TestV6WriteOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to v4 mapped address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Write to v6 address. - testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) -} - -func TestV6WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - testWriteWithoutDestination(c, unicastV6) -} - -func TestV4WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - testWriteWithoutDestination(c, unicastV4) -} - -func TestWriteOnConnectedInvalidPort(t *testing.T) { - protocols := map[string]tcpip.NetworkProtocolNumber{ - "ipv4": ipv4.ProtocolNumber, - "ipv6": ipv6.ProtocolNumber, - } - for name, pn := range protocols { - t.Run(name, func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(pn) - if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - writeOpts := tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, - } - var r bytes.Reader - payload := newPayload() - r.Reset(payload) - n, err := c.ep.Write(&r, writeOpts) - if err != nil { - c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) - } - if got, want := n, int64(len(payload)); got != want { - c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) - } - - { - err := c.ep.LastError() - if _, ok := err.(*tcpip.ErrConnectionRefused); !ok { - c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) - } - } - }) - } -} - -// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket -// that is bound to a V4 multicast address. -func TestWriteOnBoundToV4Multicast(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a -// socket that is bound to a V4-mapped multicast address. -func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4Mapped mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV6Multicast checks that we can send packets out of a -// socket that is bound to a V6 multicast address. -func TestWriteOnBoundToV6Multicast(t *testing.T) { - for _, flow := range []testFlow{unicastV6, multicastV6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V6 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV6Multicast checks that we can send packets out of a -// V6-only socket that is bound to a V6 multicast address. -func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { - for _, flow := range []testFlow{unicastV6Only, multicastV6Only} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V6 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToBroadcast checks that we can send packets out of a -// socket that is bound to the broadcast address. -func TestWriteOnBoundToBroadcast(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4 broadcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a -// socket that is bound to the V4-mapped broadcast address. -func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4Mapped mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -func TestReadIncrementsPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - // Create IPv4 UDP endpoint - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testRead(c, unicastV4) - - var want uint64 = 1 - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { - c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want) - } -} - -func TestReadIPPacketInfo(t *testing.T) { - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - flow testFlow - expectedLocalAddr tcpip.Address - expectedDestAddr tcpip.Address - }{ - { - name: "IPv4 unicast", - proto: header.IPv4ProtocolNumber, - flow: unicastV4, - expectedLocalAddr: stackAddr, - expectedDestAddr: stackAddr, - }, - { - name: "IPv4 multicast", - proto: header.IPv4ProtocolNumber, - flow: multicastV4, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: multicastAddr, - expectedDestAddr: multicastAddr, - }, - { - name: "IPv4 broadcast", - proto: header.IPv4ProtocolNumber, - flow: broadcast, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: broadcastAddr, - expectedDestAddr: broadcastAddr, - }, - { - name: "IPv6 unicast", - proto: header.IPv6ProtocolNumber, - flow: unicastV6, - expectedLocalAddr: stackV6Addr, - expectedDestAddr: stackV6Addr, - }, - { - name: "IPv6 multicast", - proto: header.IPv6ProtocolNumber, - flow: multicastV6, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedLocalAddr: multicastV6Addr, - expectedDestAddr: multicastV6Addr, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(test.proto) - - bindAddr := tcpip.FullAddress{Port: stackPort} - if err := c.ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%+v): %s", bindAddr, err) - } - - if test.flow.isMulticast() { - ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} - if err := c.ep.SetSockOpt(&ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) - } - } - - c.ep.SocketOptions().SetReceivePacketInfo(true) - - testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ - NIC: 1, - LocalAddr: test.expectedLocalAddr, - DestinationAddr: test.expectedDestAddr, - })) - - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { - t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) - } - }) - } -} - -func TestReadRecvOriginalDstAddr(t *testing.T) { - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - flow testFlow - expectedOriginalDstAddr tcpip.FullAddress - }{ - { - name: "IPv4 unicast", - proto: header.IPv4ProtocolNumber, - flow: unicastV4, - expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort}, - }, - { - name: "IPv4 multicast", - proto: header.IPv4ProtocolNumber, - flow: multicastV4, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort}, - }, - { - name: "IPv4 broadcast", - proto: header.IPv4ProtocolNumber, - flow: broadcast, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort}, - }, - { - name: "IPv6 unicast", - proto: header.IPv6ProtocolNumber, - flow: unicastV6, - expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort}, - }, - { - name: "IPv6 multicast", - proto: header.IPv6ProtocolNumber, - flow: multicastV6, - // This should actually be a unicast address assigned to the interface. - // - // TODO(gvisor.dev/issue/3556): This check is validating incorrect - // behaviour. We still include the test so that once the bug is - // resolved, this test will start to fail and the individual tasked - // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(test.proto) - - bindAddr := tcpip.FullAddress{Port: stackPort} - if err := c.ep.Bind(bindAddr); err != nil { - t.Fatalf("Bind(%#v): %s", bindAddr, err) - } - - if test.flow.isMulticast() { - ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} - if err := c.ep.SetSockOpt(&ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) - } - } - - c.ep.SocketOptions().SetReceiveOriginalDstAddress(true) - - testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr)) - - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { - t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) - } - }) - } -} - -func TestWriteIncrementsPacketsSent(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - testDualWrite(c) - - var want uint64 = 2 - if got := c.s.Stats().UDP.PacketsSent.Value(); got != want { - c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) - } -} - -func TestNoChecksum(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Disable the checksum generation. - c.ep.SocketOptions().SetNoChecksum(true) - // This option is effective on IPv4 only. - testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) - - // Enable the checksum generation. - c.ep.SocketOptions().SetNoChecksum(false) - testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) - }) - } -} - -var _ stack.NetworkInterface = (*testInterface)(nil) - -type testInterface struct { - stack.NetworkInterface -} - -func (*testInterface) ID() tcpip.NICID { - return 0 -} - -func (*testInterface) Enabled() bool { - return true -} - -func TestTTL(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const multicastTTL = 42 - if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil { - c.t.Fatalf("SetSockOptInt failed: %s", err) - } - - var wantTTL uint8 - if flow.isMulticast() { - wantTTL = multicastTTL - } else { - var p stack.NetworkProtocolFactory - var n tcpip.NetworkProtocolNumber - if flow.isV4() { - p = ipv4.NewProtocol - n = ipv4.ProtocolNumber - } else { - p = ipv6.NewProtocol - n = ipv6.ProtocolNumber - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{p}, - }) - ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil) - wantTTL = ep.DefaultTTL() - ep.Close() - } - - testWrite(c, flow, checker.TTL(wantTTL)) - }) - } -} - -func TestSetTTL(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { - t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { - c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) - } - - testWrite(c, flow, checker.TTL(wantTTL)) - }) - } - }) - } -} - -func TestSetTOS(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const tos = testTOS - v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption) - if err != nil { - c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) - } - // Test for expected default value. - if v != 0 { - c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) - } - - if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { - c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err) - } - - v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption) - if err != nil { - c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) - } - - if v != tos { - c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos) - } - - testWrite(c, flow, checker.TOS(tos, 0)) - }) - } -} - -func TestSetTClass(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const tClass = testTOS - v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) - if err != nil { - c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) - } - // Test for expected default value. - if v != 0 { - c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0) - } - - if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil { - c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err) - } - - v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) - if err != nil { - c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) - } - - if v != tClass { - c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass) - } - - // The header getter for TClass is called TOS, so use that checker. - testWrite(c, flow, checker.TOS(tClass, 0)) - }) - } -} - -func TestReceiveTosTClass(t *testing.T) { - const RcvTOSOpt = "ReceiveTosOption" - const RcvTClassOpt = "ReceiveTClassOption" - - testCases := []struct { - name string - tests []testFlow - }{ - {RcvTOSOpt, []testFlow{unicastV4, broadcast}}, - {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, - } - for _, testCase := range testCases { - for _, flow := range testCase.tests { - t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - name := testCase.name - - var optionGetter func() bool - var optionSetter func(bool) - switch name { - case RcvTOSOpt: - optionGetter = c.ep.SocketOptions().GetReceiveTOS - optionSetter = c.ep.SocketOptions().SetReceiveTOS - case RcvTClassOpt: - optionGetter = c.ep.SocketOptions().GetReceiveTClass - optionSetter = c.ep.SocketOptions().SetReceiveTClass - default: - t.Fatalf("unkown test variant: %s", name) - } - - // Verify that setting and reading the option works. - v := optionGetter() - // Test for expected default value. - if v != false { - c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) - } - - want := true - optionSetter(want) - - got := optionGetter() - if got != want { - c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) - } - - // Verify that the correct received TOS or TClass is handed through as - // ancillary data to the ControlMessages struct. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - switch name { - case RcvTClassOpt: - testRead(c, flow, checker.ReceiveTClass(testTOS)) - case RcvTOSOpt: - testRead(c, flow, checker.ReceiveTOS(testTOS)) - default: - t.Fatalf("unknown test variant: %s", name) - } - }) - } - } -} - -func TestMulticastInterfaceOption(t *testing.T) { - for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - for _, bindTyp := range []string{"bound", "unbound"} { - t.Run(bindTyp, func(t *testing.T) { - for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { - t.Run(optTyp, func(t *testing.T) { - h := flow.header4Tuple(outgoing) - mcastAddr := h.dstAddr.Addr - localIfAddr := h.srcAddr.Addr - - var ifoptSet tcpip.MulticastInterfaceOption - switch optTyp { - case "use local-addr": - ifoptSet.InterfaceAddr = localIfAddr - case "use NICID": - ifoptSet.NIC = 1 - case "use local-addr and NIC": - ifoptSet.InterfaceAddr = localIfAddr - ifoptSet.NIC = 1 - default: - t.Fatal("unknown test variant") - } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(flow.sockProto()) - - if bindTyp == "bound" { - // Bind the socket by connecting to the multicast address. - // This may have an influence on how the multicast interface - // is set. - addr := tcpip.FullAddress{ - Addr: flow.mapAddrIfApplicable(mcastAddr), - Port: stackPort, - } - if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - } - - if err := c.ep.SetSockOpt(&ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) - } - - // Verify multicast interface addr and NIC were set correctly. - // Note that NIC must be 1 since this is our outgoing interface. - var ifoptGot tcpip.MulticastInterfaceOption - if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err) - } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant { - c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant) - } - }) - } - }) - } - }) - } -} - -// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination -// Unreachable message when a udp datagram is received on ports for which there -// is no bound udp socket. -func TestV4UnknownDestination(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - testCases := []struct { - flow testFlow - icmpRequired bool - // largePayload if true, will result in a payload large enough - // so that the final generated IPv4 packet is larger than - // header.IPv4MinimumProcessableDatagramSize. - largePayload bool - // badChecksum if true, will set an invalid checksum in the - // header. - badChecksum bool - }{ - {unicastV4, true, false, false}, - {unicastV4, true, true, false}, - {unicastV4, false, false, true}, - {unicastV4, false, true, true}, - {multicastV4, false, false, false}, - {multicastV4, false, true, false}, - {broadcast, false, false, false}, - {broadcast, false, true, false}, - } - checksumErrors := uint64(0) - for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { - payload := newPayload() - if tc.largePayload { - payload = newMinPayload(576) - } - c.injectPacket(tc.flow, payload, tc.badChecksum) - if tc.badChecksum { - checksumErrors++ - if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { - t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - } - if !tc.icmpRequired { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if p, ok := c.linkEP.ReadContext(ctx); ok { - t.Fatalf("unexpected packet received: %+v", p) - } - return - } - - // ICMP required. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - t.Fatalf("packet wasn't written out") - return - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - pkt := vv.ToView() - if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } - - hdr := header.IPv4(pkt) - checker.IPv4(t, hdr, checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - - // We need to compare the included data part of the UDP packet that is in - // the ICMP packet with the matching original data. - icmpPkt := header.ICMPv4(hdr.Payload()) - payloadIPHeader := header.IPv4(icmpPkt.Payload()) - incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize - wantLen := len(payload) - if tc.largePayload { - // To work out the data size we need to simulate what the sender would - // have done. The wanted size is the total available minus the sum of - // the headers in the UDP AND ICMP packets, given that we know the test - // had only a minimal IP header but the ICMP sender will have allowed - // for a maximally sized packet header. - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength - } - - // In the case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - // Add back the two headers within the payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) - - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) - } - }) - } -} - -// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination -// Unreachable message when a udp datagram is received on ports for which there -// is no bound udp socket. -func TestV6UnknownDestination(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - testCases := []struct { - flow testFlow - icmpRequired bool - // largePayload if true will result in a payload large enough to - // create an IPv6 packet > header.IPv6MinimumMTU bytes. - largePayload bool - // badChecksum if true, will set an invalid checksum in the - // header. - badChecksum bool - }{ - {unicastV6, true, false, false}, - {unicastV6, true, true, false}, - {unicastV6, false, false, true}, - {unicastV6, false, true, true}, - {multicastV6, false, false, false}, - {multicastV6, false, true, false}, - } - checksumErrors := uint64(0) - for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { - payload := newPayload() - if tc.largePayload { - payload = newMinPayload(1280) - } - c.injectPacket(tc.flow, payload, tc.badChecksum) - if tc.badChecksum { - checksumErrors++ - if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { - t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - } - if !tc.icmpRequired { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if p, ok := c.linkEP.ReadContext(ctx); ok { - t.Fatalf("unexpected packet received: %+v", p) - } - return - } - - // ICMP required. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - t.Fatalf("packet wasn't written out") - return - } - - vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) - pkt := vv.ToView() - if got, want := len(pkt), header.IPv6MinimumMTU; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } - - hdr := header.IPv6(pkt) - checker.IPv6(t, hdr, checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - - icmpPkt := header.ICMPv6(hdr.Payload()) - payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize - } - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) - - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) - } - }) - } -} - -// TestIncrementMalformedPacketsReceived verifies if the malformed received -// global and endpoint stats are incremented. -func TestIncrementMalformedPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV6.header4Tuple(incoming) - buf := c.buildV6Packet(payload, &h) - - // Invalidate the UDP header length field. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.SetLength(u.Length() + 1) - - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) - } -} - -// TestShortHeader verifies that when a packet with a too-short UDP header is -// received, the malformed received global stat gets incremented. -func TestShortHeader(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - h := unicastV6.header4Tuple(incoming) - - // Allocate a buffer for an IPv6 and too-short UDP header. - const udpSize = header.UDPMinimumSize - 1 - buf := buffer.NewView(header.IPv6MinimumSize + udpSize) - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(udpSize), - TransportProtocol: udp.ProtocolNumber, - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - - // Initialize the UDP header. - udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize)) - udpHdr.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: header.UDPMinimumSize, - }) - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr))) - udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) - // Copy all but the last byte of the UDP header into the packet. - copy(buf[header.IPv6MinimumSize:], udpHdr) - - // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want { - t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want) - } -} - -// TestBadChecksumErrors verifies if a checksum error is detected, -// global and endpoint stats are incremented. -func TestBadChecksumErrors(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV6} { - t.Run(flow.String(), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(flow.sockProto()) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - c.injectPacket(flow, payload, true /* badChecksum */) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } - }) - } -} - -// TestPayloadModifiedV4 verifies if a checksum error is detected, -// global and endpoint stats are incremented. -func TestPayloadModifiedV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv4.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - buf := c.buildV4Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be - // incorrect. - buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestPayloadModifiedV6 verifies if a checksum error is detected, -// global and endpoint stats are incremented. -func TestPayloadModifiedV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV6.header4Tuple(incoming) - buf := c.buildV6Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be - // incorrect. - buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestChecksumZeroV4 verifies if the checksum value is zero, global and -// endpoint states are *not* incremented (UDP checksum is optional on IPv4). -func TestChecksumZeroV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv4.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - buf := c.buildV4Packet(payload, &h) - // Set the checksum field in the UDP header to zero. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.SetChecksum(0) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 0 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestChecksumZeroV6 verifies if the checksum value is zero, global and -// endpoint states are incremented (UDP checksum is *not* optional on IPv6). -func TestChecksumZeroV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV6.header4Tuple(incoming) - buf := c.buildV6Packet(payload, &h) - // Set the checksum field in the UDP header to zero. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.SetChecksum(0) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestShutdownRead verifies endpoint read shutdown and error -// stats increment on packet receive. -func TestShutdownRead(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - testFailingRead(c, unicastV6, true /* expectReadError */) - - var want uint64 = 1 - if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { - t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) - } -} - -// TestShutdownWrite verifies endpoint write shutdown and error -// stats increment on packet write. -func TestShutdownWrite(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %s", err) - } - - if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %s", err) - } - - testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{}) -} - -func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { - got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err.(type) { - case nil: - want.PacketsSent.IncrementBy(incr) - case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: - want.WriteErrors.InvalidArgs.IncrementBy(incr) - case *tcpip.ErrClosedForSend: - want.WriteErrors.WriteClosed.IncrementBy(incr) - case *tcpip.ErrInvalidEndpointState: - want.WriteErrors.InvalidEndpointState.IncrementBy(incr) - case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: - want.SendErrors.NoRoute.IncrementBy(incr) - default: - want.SendErrors.SendToNetworkFailed.IncrementBy(incr) - } - if got != want { - c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) - } -} - -func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { - got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err.(type) { - case nil, *tcpip.ErrWouldBlock: - case *tcpip.ErrClosedForReceive: - want.ReadErrors.ReadClosed.IncrementBy(incr) - default: - c.t.Errorf("Endpoint error missing stats update err %v", err) - } - if got != want { - c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) - } -} - -func TestOutgoingSubnetBroadcast(t *testing.T) { - const nicID1 = 1 - - ipv4Addr := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 24, - } - ipv4Subnet := ipv4Addr.Subnet() - ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := testutil.MustParse4("192.168.1.1") - ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 31, - } - ipv4Subnet31 := ipv4AddrPrefix31.Subnet() - ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() - ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 32, - } - ipv4Subnet32 := ipv4AddrPrefix32.Subnet() - ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() - ipv6Addr := tcpip.AddressWithPrefix{ - Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - PrefixLen: 64, - } - ipv6Subnet := ipv6Addr.Subnet() - ipv6SubnetBcast := ipv6Subnet.Broadcast() - remNetAddr := tcpip.AddressWithPrefix{ - Address: "\x64\x0a\x7b\x18", - PrefixLen: 24, - } - remNetSubnet := remNetAddr.Subnet() - remNetSubnetBcast := remNetSubnet.Broadcast() - - tests := []struct { - name string - nicAddr tcpip.ProtocolAddress - routes []tcpip.Route - remoteAddr tcpip.Address - requiresBroadcastOpt bool - }{ - { - name: "IPv4 Broadcast to local subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet, - NIC: nicID1, - }, - }, - remoteAddr: ipv4SubnetBcast, - requiresBroadcastOpt: true, - }, - { - name: "IPv4 Broadcast to local /31 subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4AddrPrefix31, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet31, - NIC: nicID1, - }, - }, - remoteAddr: ipv4Subnet31Bcast, - requiresBroadcastOpt: false, - }, - { - name: "IPv4 Broadcast to local /32 subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4AddrPrefix32, - }, - routes: []tcpip.Route{ - { - Destination: ipv4Subnet32, - NIC: nicID1, - }, - }, - remoteAddr: ipv4Subnet32Bcast, - requiresBroadcastOpt: false, - }, - // IPv6 has no notion of a broadcast. - { - name: "IPv6 'Broadcast' to local subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: ipv6Addr, - }, - routes: []tcpip.Route{ - { - Destination: ipv6Subnet, - NIC: nicID1, - }, - }, - remoteAddr: ipv6SubnetBcast, - requiresBroadcastOpt: false, - }, - { - name: "IPv4 Broadcast to remote subnet", - nicAddr: tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: ipv4Addr, - }, - routes: []tcpip.Route{ - { - Destination: remNetSubnet, - Gateway: ipv4Gateway, - NIC: nicID1, - }, - }, - remoteAddr: remNetSubnetBcast, - // TODO(gvisor.dev/issue/3938): Once we support marking a route as - // broadcast, this test should require the broadcast option to be set. - requiresBroadcastOpt: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID1, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) - } - - s.SetRouteTable(test.routes) - - var netProto tcpip.NetworkProtocolNumber - switch l := len(test.remoteAddr); l { - case header.IPv4AddressSize: - netProto = header.IPv4ProtocolNumber - case header.IPv6AddressSize: - netProto = header.IPv6ProtocolNumber - default: - t.Fatalf("got unexpected address length = %d bytes", l) - } - - wq := waiter.Queue{} - ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err) - } - defer ep.Close() - - var r bytes.Reader - data := []byte{1, 2, 3, 4} - to := tcpip.FullAddress{ - Addr: test.remoteAddr, - Port: 80, - } - opts := tcpip.WriteOptions{To: &to} - expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error { - if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok { - return nil - } - return &tcpip.ErrBroadcastDisabled{} - } - if !test.requiresBroadcastOpt { - expectedErrWithoutBcastOpt = nil - } - - r.Reset(data) - { - n, err := ep.Write(&r, opts) - if expectedErrWithoutBcastOpt != nil { - if want := expectedErrWithoutBcastOpt(err); want != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) - } - } else if err != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) - } - } - - ep.SocketOptions().SetBroadcast(true) - - r.Reset(data) - if n, err := ep.Write(&r, opts); err != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) - } - - ep.SocketOptions().SetBroadcast(false) - - r.Reset(data) - { - n, err := ep.Write(&r, opts) - if expectedErrWithoutBcastOpt != nil { - if want := expectedErrWithoutBcastOpt(err); want != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) - } - } else if err != nil { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) - } - } - }) - } -} |