diff options
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r-- | pkg/tcpip/transport/udp/BUILD | 61 | ||||
-rwxr-xr-x | pkg/tcpip/transport/udp/udp_packet_list.go | 173 | ||||
-rwxr-xr-x | pkg/tcpip/transport/udp/udp_state_autogen.go | 138 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 1780 |
4 files changed, 311 insertions, 1841 deletions
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD deleted file mode 100644 index adc908e24..000000000 --- a/pkg/tcpip/transport/udp/BUILD +++ /dev/null @@ -1,61 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "udp_packet_list", - out = "udp_packet_list.go", - package = "udp", - prefix = "udpPacket", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*udpPacket", - "Linker": "*udpPacket", - }, -) - -go_library( - name = "udp", - srcs = [ - "endpoint.go", - "endpoint_state.go", - "forwarder.go", - "protocol.go", - "udp_packet_list.go", - ], - imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/sleep", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/iptables", - "//pkg/tcpip/ports", - "//pkg/tcpip/stack", - "//pkg/tcpip/transport/raw", - "//pkg/waiter", - ], -) - -go_test( - name = "udp_x_test", - size = "small", - srcs = ["udp_test.go"], - deps = [ - ":udp", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/waiter", - ], -) diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go new file mode 100755 index 000000000..673a9373b --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_packet_list.go @@ -0,0 +1,173 @@ +package udp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type udpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type udpPacketList struct { + head *udpPacket + tail *udpPacket +} + +// Reset resets list l to the empty state. +func (l *udpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *udpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *udpPacketList) Front() *udpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *udpPacketList) Back() *udpPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *udpPacketList) PushFront(e *udpPacket) { + udpPacketElementMapper{}.linkerFor(e).SetNext(l.head) + udpPacketElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *udpPacketList) PushBack(e *udpPacket) { + udpPacketElementMapper{}.linkerFor(e).SetNext(nil) + udpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *udpPacketList) PushBackList(m *udpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *udpPacketList) InsertAfter(b, e *udpPacket) { + a := udpPacketElementMapper{}.linkerFor(b).Next() + udpPacketElementMapper{}.linkerFor(e).SetNext(a) + udpPacketElementMapper{}.linkerFor(e).SetPrev(b) + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *udpPacketList) InsertBefore(a, e *udpPacket) { + b := udpPacketElementMapper{}.linkerFor(a).Prev() + udpPacketElementMapper{}.linkerFor(e).SetNext(a) + udpPacketElementMapper{}.linkerFor(e).SetPrev(b) + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *udpPacketList) Remove(e *udpPacket) { + prev := udpPacketElementMapper{}.linkerFor(e).Prev() + next := udpPacketElementMapper{}.linkerFor(e).Next() + + if prev != nil { + udpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + udpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type udpPacketEntry struct { + next *udpPacket + prev *udpPacket +} + +// Next returns the entry that follows e in the list. +func (e *udpPacketEntry) Next() *udpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *udpPacketEntry) Prev() *udpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *udpPacketEntry) SetNext(elem *udpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *udpPacketEntry) SetPrev(elem *udpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go new file mode 100755 index 000000000..75b78fa01 --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -0,0 +1,138 @@ +// automatically generated by stateify. + +package udp + +import ( + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func (x *udpPacket) beforeSave() {} +func (x *udpPacket) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("udpPacketEntry", &x.udpPacketEntry) + m.Save("senderAddress", &x.senderAddress) + m.Save("timestamp", &x.timestamp) + m.Save("tos", &x.tos) +} + +func (x *udpPacket) afterLoad() {} +func (x *udpPacket) load(m state.Map) { + m.Load("udpPacketEntry", &x.udpPacketEntry) + m.Load("senderAddress", &x.senderAddress) + m.Load("timestamp", &x.timestamp) + m.Load("tos", &x.tos) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("uniqueID", &x.uniqueID) + m.Save("rcvReady", &x.rcvReady) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("state", &x.state) + m.Save("dstPort", &x.dstPort) + m.Save("v6only", &x.v6only) + m.Save("ttl", &x.ttl) + m.Save("multicastTTL", &x.multicastTTL) + m.Save("multicastAddr", &x.multicastAddr) + m.Save("multicastNICID", &x.multicastNICID) + m.Save("multicastLoop", &x.multicastLoop) + m.Save("reusePort", &x.reusePort) + m.Save("bindToDevice", &x.bindToDevice) + m.Save("broadcast", &x.broadcast) + m.Save("boundBindToDevice", &x.boundBindToDevice) + m.Save("boundPortFlags", &x.boundPortFlags) + m.Save("sendTOS", &x.sendTOS) + m.Save("receiveTOS", &x.receiveTOS) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("multicastMemberships", &x.multicastMemberships) + m.Save("effectiveNetProtos", &x.effectiveNetProtos) +} + +func (x *endpoint) load(m state.Map) { + m.Load("TransportEndpointInfo", &x.TransportEndpointInfo) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("uniqueID", &x.uniqueID) + m.Load("rcvReady", &x.rcvReady) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("state", &x.state) + m.Load("dstPort", &x.dstPort) + m.Load("v6only", &x.v6only) + m.Load("ttl", &x.ttl) + m.Load("multicastTTL", &x.multicastTTL) + m.Load("multicastAddr", &x.multicastAddr) + m.Load("multicastNICID", &x.multicastNICID) + m.Load("multicastLoop", &x.multicastLoop) + m.Load("reusePort", &x.reusePort) + m.Load("bindToDevice", &x.bindToDevice) + m.Load("broadcast", &x.broadcast) + m.Load("boundBindToDevice", &x.boundBindToDevice) + m.Load("boundPortFlags", &x.boundPortFlags) + m.Load("sendTOS", &x.sendTOS) + m.Load("receiveTOS", &x.receiveTOS) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("multicastMemberships", &x.multicastMemberships) + m.Load("effectiveNetProtos", &x.effectiveNetProtos) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *multicastMembership) beforeSave() {} +func (x *multicastMembership) save(m state.Map) { + x.beforeSave() + m.Save("nicID", &x.nicID) + m.Save("multicastAddr", &x.multicastAddr) +} + +func (x *multicastMembership) afterLoad() {} +func (x *multicastMembership) load(m state.Map) { + m.Load("nicID", &x.nicID) + m.Load("multicastAddr", &x.multicastAddr) +} + +func (x *udpPacketList) beforeSave() {} +func (x *udpPacketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *udpPacketList) afterLoad() {} +func (x *udpPacketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *udpPacketEntry) beforeSave() {} +func (x *udpPacketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *udpPacketEntry) afterLoad() {} +func (x *udpPacketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("udp.udpPacket", (*udpPacket)(nil), state.Fns{Save: (*udpPacket).save, Load: (*udpPacket).load}) + state.Register("udp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("udp.multicastMembership", (*multicastMembership)(nil), state.Fns{Save: (*multicastMembership).save, Load: (*multicastMembership).load}) + state.Register("udp.udpPacketList", (*udpPacketList)(nil), state.Fns{Save: (*udpPacketList).save, Load: (*udpPacketList).load}) + state.Register("udp.udpPacketEntry", (*udpPacketEntry)(nil), state.Fns{Save: (*udpPacketEntry).save, Load: (*udpPacketEntry).load}) +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go deleted file mode 100644 index f0ff3fe71..000000000 --- a/pkg/tcpip/transport/udp/udp_test.go +++ /dev/null @@ -1,1780 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package udp_test - -import ( - "bytes" - "context" - "fmt" - "math/rand" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// Addresses and ports used for testing. It is recommended that tests stick to -// using these addresses as it allows using the testFlow helper. -// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*' -// represents the remote endpoint. -const ( - v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = v4MappedAddrPrefix + stackAddr - testV4MappedAddr = v4MappedAddrPrefix + testAddr - multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr - broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr - v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00" - - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testAddr = "\x0a\x00\x00\x02" - testPort = 4096 - multicastAddr = "\xe8\x2b\xd3\xea" - multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - broadcastAddr = header.IPv4Broadcast - testTOS = 0x80 - - // defaultMTU is the MTU, in bytes, used throughout the tests, except - // where another value is explicitly used. It is chosen to match the MTU - // of loopback interfaces on linux systems. - defaultMTU = 65536 -) - -// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in -// a packet header. These values are used to populate a header or verify one. -// Note that because they are used in packet headers, the addresses are never in -// a V4-mapped format. -type header4Tuple struct { - srcAddr tcpip.FullAddress - dstAddr tcpip.FullAddress -} - -// testFlow implements a helper type used for sending and receiving test -// packets. A given test flow value defines 1) the socket endpoint used for the -// test and 2) the type of packet send or received on the endpoint. E.g., a -// multicastV6Only flow is a V6 multicast packet passing through a V6-only -// endpoint. The type provides helper methods to characterize the flow (e.g., -// isV4) as well as return a proper header4Tuple for it. -type testFlow int - -const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket -) - -func (flow testFlow) String() string { - switch flow { - case unicastV4: - return "unicastV4" - case unicastV6: - return "unicastV6" - case unicastV6Only: - return "unicastV6Only" - case unicastV4in6: - return "unicastV4in6" - case multicastV4: - return "multicastV4" - case multicastV6: - return "multicastV6" - case multicastV6Only: - return "multicastV6Only" - case multicastV4in6: - return "multicastV4in6" - case broadcast: - return "broadcast" - case broadcastIn6: - return "broadcastIn6" - default: - return "unknown" - } -} - -// packetDirection explains if a flow is incoming (read) or outgoing (write). -type packetDirection int - -const ( - incoming packetDirection = iota - outgoing -) - -// header4Tuple returns the header4Tuple for the given flow and direction. Note -// that the tuple contains no mapped addresses as those only exist at the socket -// level but not at the packet header level. -func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { - var h header4Tuple - if flow.isV4() { - if d == outgoing { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, - dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, - } - } else { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, - dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, - } - } - if flow.isMulticast() { - h.dstAddr.Addr = multicastAddr - } else if flow.isBroadcast() { - h.dstAddr.Addr = broadcastAddr - } - } else { // IPv6 - if d == outgoing { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - } - } else { - h = header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - } - } - if flow.isMulticast() { - h.dstAddr.Addr = multicastV6Addr - } - } - return h -} - -func (flow testFlow) getMcastAddr() tcpip.Address { - if flow.isV4() { - return multicastAddr - } - return multicastV6Addr -} - -// mapAddrIfApplicable converts the given V4 address into its V4-mapped version -// if it is applicable to the flow. -func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { - if flow.isMapped() { - return v4MappedAddrPrefix + v4Addr - } - return v4Addr -} - -// netProto returns the protocol number used for the network packet. -func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { - if flow.isV4() { - return ipv4.ProtocolNumber - } - return ipv6.ProtocolNumber -} - -// sockProto returns the protocol number used when creating the socket -// endpoint for this flow. -func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { - switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: - return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: - return ipv4.ProtocolNumber - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { - if flow.isV4() { - return checker.IPv4 - } - return checker.IPv6 -} - -func (flow testFlow) isV6() bool { return !flow.isV4() } -func (flow testFlow) isV4() bool { - return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() -} - -func (flow testFlow) isV6Only() bool { - switch flow { - case unicastV6Only, multicastV6Only: - return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isMulticast() bool { - switch flow { - case multicastV4, multicastV4in6, multicastV6, multicastV6Only: - return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isBroadcast() bool { - switch flow { - case broadcast, broadcastIn6: - return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -func (flow testFlow) isMapped() bool { - switch flow { - case unicastV4in6, multicastV4in6, broadcastIn6: - return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: - return false - default: - panic(fmt.Sprintf("invalid testFlow given: %d", flow)) - } -} - -type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -func newDualTestContext(t *testing.T, mtu uint32) *testContext { - t.Helper() - return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - }) -} - -func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext { - t.Helper() - - s := stack.New(options) - ep := channel.New(256, mtu, "") - wep := stack.LinkEndpoint(ep) - - if testing.Verbose() { - wep = sniffer.New(ep) - } - if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) - } - - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - return &testContext{ - t: t, - s: s, - linkEP: ep, - } -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { - c.t.Helper() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) - if err != nil { - c.t.Fatal("NewEndpoint failed: ", err) - } -} - -func (c *testContext) createEndpointForFlow(flow testFlow) { - c.t.Helper() - - c.createEndpoint(flow.sockProto()) - if flow.isV6Only() { - if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - } else if flow.isBroadcast() { - if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { - c.t.Fatal("SetSockOpt failed:", err) - } - } -} - -// getPacketAndVerify reads a packet from the link endpoint and verifies the -// header against expected values from the given test flow. In addition, it -// calls any extra checker functions provided. -func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { - c.t.Helper() - - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - c.t.Fatalf("Packet wasn't written out") - return nil - } - - if p.Proto != flow.netProto() { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) - } - - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - h := flow.header4Tuple(outgoing) - checkers = append( - checkers, - checker.SrcAddr(h.srcAddr.Addr), - checker.DstAddr(h.dstAddr.Addr), - checker.UDP(checker.DstPort(h.dstAddr.Port)), - ) - flow.checkerFn()(c.t, b, checkers...) - return b -} - -// injectPacket creates a packet of the given flow and with the given payload, -// and injects it into the link endpoint. -func (c *testContext) injectPacket(flow testFlow, payload []byte) { - c.t.Helper() - - h := flow.header4Tuple(incoming) - if flow.isV4() { - c.injectV4Packet(payload, &h, true /* valid */) - } else { - c.injectV6Packet(payload, &h, true /* valid */) - } -} - -// injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv6(buf) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv6MinimumSize:]) - l := uint16(header.UDPMinimumSize + len(payload)) - if !valid { - // Change the UDP payload length to corrupt the header - // as requested by the caller. - l++ - } - u.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: l, - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), - }) -} - -// injectV4Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { - // Allocate a buffer for data and headers. - buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - payloadStart := len(buf) - len(payload) - copy(buf[payloadStart:], payload) - - // Initialize the IP header. - ip := header.IPv4(buf) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TOS: testTOS, - TotalLength: uint16(len(buf)), - TTL: 65, - Protocol: uint8(udp.ProtocolNumber), - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - // Initialize the UDP header. - u := header.UDP(buf[header.IPv4MinimumSize:]) - u.Encode(&header.UDPFields{ - SrcPort: h.srcAddr.Port, - DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), - }) - - // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) - - // Calculate the UDP checksum and set it. - xsum = header.Checksum(payload, xsum) - u.SetChecksum(^u.CalculateChecksum(xsum)) - - // Inject packet. - - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), - }) -} - -func newPayload() []byte { - return newMinPayload(30) -} - -func newMinPayload(minSize int) []byte { - b := make([]byte, minSize+rand.Intn(100)) - for i := range b { - b[i] = byte(rand.Intn(256)) - } - return b -} - -func TestBindToDeviceOption(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) - } - defer ep.Close() - - opts := stack.NICOptions{Name: "my_device"} - if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { - t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) - } - - // nicIDPtr is used instead of taking the address of NICID literals, which is - // a compiler error. - nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { - return &s - } - - testActions := []struct { - name string - setBindToDevice *tcpip.NICID - setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption - }{ - {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, - {"BindToExistent", nicIDPtr(321), nil, 321}, - {"UnbindToDevice", nicIDPtr(0), nil, 0}, - } - for _, testAction := range testActions { - t.Run(testAction.name, func(t *testing.T) { - if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) - } - } - bindToDevice := tcpip.BindToDeviceOption(88888) - if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt got %v, want %v", err, nil) - } - if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %d, want %d", got, want) - } - }) - } -} - -// testReadInternal sends a packet of the given test flow into the stack by -// injecting it into the link endpoint. It then attempts to read it from the -// UDP endpoint and depending on if this was expected to succeed verifies its -// correctness including any additional checker functions provided. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { - c.t.Helper() - - payload := newPayload() - c.injectPacket(flow, payload) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.EventIn) - defer c.wq.EventUnregister(&we) - - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - - var addr tcpip.FullAddress - v, cm, err := c.ep.Read(&addr) - if err == tcpip.ErrWouldBlock { - // Wait for data to become available. - select { - case <-ch: - v, cm, err = c.ep.Read(&addr) - - case <-time.After(300 * time.Millisecond): - if packetShouldBeDropped { - return // expected to time out - } - c.t.Fatal("timed out waiting for data") - } - } - - if expectReadError && err != nil { - c.checkEndpointReadStats(1, epstats, err) - return - } - - if err != nil { - c.t.Fatal("Read failed:", err) - } - - if packetShouldBeDropped { - c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr) - } - - // Check the peer address. - h := flow.header4Tuple(incoming) - if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) - } - - // Check the payload. - if !bytes.Equal(payload, v) { - c.t.Fatalf("bad payload: got %x, want %x", v, payload) - } - - // Run any checkers against the ControlMessages. - for _, f := range checkers { - f(c.t, cm) - } - - c.checkEndpointReadStats(1, epstats, err) -} - -// testRead sends a packet of the given test flow into the stack by injecting it -// into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness including any additional checker functions provided. -func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { - c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) -} - -// testFailingRead sends a packet of the given test flow into the stack by -// injecting it into the link endpoint. It then tries to read it from the UDP -// endpoint and expects this to fail. -func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { - c.t.Helper() - testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) -} - -func TestBindEphemeralPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } -} - -func TestBindReservedPort(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - addr, err := c.ep.GetLocalAddress() - if err != nil { - t.Fatalf("GetLocalAddress failed: %v", err) - } - - // We can't bind the address reserved by the connected endpoint above. - { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) - } - } - - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - // We can't bind ipv4-any on the port reserved by the connected endpoint - // above, since the endpoint is dual-stack. - if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) - } - // We can bind an ipv4 address on this port, though. - if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - }() - - // Once the connected endpoint releases its port reservation, we are able to - // bind ipv4-any once again. - c.ep.Close() - func() { - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - defer ep.Close() - if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } - }() -} - -func TestV4ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to v4 mapped wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV4ReadOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4in6) - - // Bind to local address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV4in6) -} - -func TestV6ReadOnV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV6) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV6) -} - -// TestV4ReadSelfSource checks that packets coming from a local IP address are -// correctly dropped when handleLocal is true and not otherwise. -func TestV4ReadSelfSource(t *testing.T) { - for _, tt := range []struct { - name string - handleLocal bool - wantErr *tcpip.Error - wantInvalidSource uint64 - }{ - {"HandleLocal", false, nil, 0}, - {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1}, - } { - t.Run(tt.name, func(t *testing.T) { - c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - HandleLocal: tt.handleLocal, - }) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4) - - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - h.srcAddr = h.dstAddr - - c.injectV4Packet(payload, &h, true /* valid */) - - if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { - t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) - } - - if _, _, err := c.ep.Read(nil); err != tt.wantErr { - t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr) - } - }) - } -} - -func TestV4ReadOnV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV4) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Test acceptance. - testRead(c, unicastV4) -} - -// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast -// address and receive data sent to that address. -func TestReadOnBoundToMulticast(t *testing.T) { - // FIXME(b/128189410): multicastV4in6 currently doesn't work as - // AddMembershipOption doesn't handle V4in6 addresses. - for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to multicast address. - mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr()) - if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - // Join multicast group. - ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatal("SetSockOpt failed:", err) - } - - // Check that we receive multicast packets but not unicast or broadcast - // ones. - testRead(c, flow) - testFailingRead(c, broadcast, false /* expectReadError */) - testFailingRead(c, unicastV4, false /* expectReadError */) - }) - } -} - -// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast -// address and can receive only broadcast data. -func TestV4ReadOnBoundToBroadcast(t *testing.T) { - for _, flow := range []testFlow{broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to broadcast address. - bcastAddr := flow.mapAddrIfApplicable(broadcastAddr) - if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Check that we receive broadcast packets but not unicast ones. - testRead(c, flow) - testFailingRead(c, unicastV4, false /* expectReadError */) - }) - } -} - -// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY -// and receive broadcast and unicast data. -func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { - for _, flow := range []testFlow{broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s (", err) - } - - // Check that we receive both broadcast and unicast packets. - testRead(c, flow) - testRead(c, unicastV4) - }) - } -} - -// testFailingWrite sends a packet of the given test flow into the UDP endpoint -// and verifies it fails with the provided error code. -func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { - c.t.Helper() - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - h := flow.header4Tuple(outgoing) - writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - - payload := buffer.View(newPayload()) - _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, - }) - c.checkEndpointWriteStats(1, epstats, gotErr) - if gotErr != wantErr { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) - } -} - -// testWrite sends a packet of the given test flow from the UDP endpoint to the -// flow's destination address:port. It then receives it from the link endpoint -// and verifies its correctness including any additional checker functions -// provided. -func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - return testWriteInternal(c, flow, true, checkers...) -} - -// testWriteWithoutDestination sends a packet of the given test flow from the -// UDP endpoint without giving a destination address:port. It then receives it -// from the link endpoint and verifies its correctness including any additional -// checker functions provided. -func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - return testWriteInternal(c, flow, false, checkers...) -} - -func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { - c.t.Helper() - // Take a snapshot of the stats to validate them at the end of the test. - epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - - writeOpts := tcpip.WriteOptions{} - if setDest { - h := flow.header4Tuple(outgoing) - writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - writeOpts = tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, - } - } - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != int64(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - c.checkEndpointWriteStats(1, epstats, err) - // Received the packet and check the payload. - b := c.getPacketAndVerify(flow, checkers...) - var udp header.UDP - if flow.isV4() { - udp = header.UDP(header.IPv4(b).Payload()) - } else { - udp = header.UDP(header.IPv6(b).Payload()) - } - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } - - return udp.SourcePort() -} - -func testDualWrite(c *testContext) uint16 { - c.t.Helper() - - v4Port := testWrite(c, unicastV4in6) - v6Port := testWrite(c, unicastV6) - if v4Port != v6Port { - c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) - } - - return v4Port -} - -func TestDualWriteUnbound(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - testDualWrite(c) -} - -func TestDualWriteBoundToWildcard(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - p := testDualWrite(c) - if p != stackPort { - c.t.Fatalf("Bad port: got %v, want %v", p, stackPort) - } -} - -func TestDualWriteConnectedToV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - testWrite(c, unicastV6) - - // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) - const want = 1 - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { - c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) - } -} - -func TestDualWriteConnectedToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - testWrite(c, unicastV4in6) - - // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) -} - -func TestV4WriteOnV6Only(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(unicastV6Only) - - // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute) -} - -func TestV6WriteOnBoundToV4Mapped(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to v4 mapped address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) -} - -func TestV6WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v6 address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - testWriteWithoutDestination(c, unicastV6) -} - -func TestV4WriteOnConnected(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Connect to v4 mapped address. - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - testWriteWithoutDestination(c, unicastV4) -} - -// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket -// that is bound to a V4 multicast address. -func TestWriteOnBoundToV4Multicast(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a -// socket that is bound to a V4-mapped multicast address. -func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4Mapped mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV6Multicast checks that we can send packets out of a -// socket that is bound to a V6 multicast address. -func TestWriteOnBoundToV6Multicast(t *testing.T) { - for _, flow := range []testFlow{unicastV6, multicastV6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V6 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV6Multicast checks that we can send packets out of a -// V6-only socket that is bound to a V6 multicast address. -func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { - for _, flow := range []testFlow{unicastV6Only, multicastV6Only} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V6 mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToBroadcast checks that we can send packets out of a -// socket that is bound to the broadcast address. -func TestWriteOnBoundToBroadcast(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4 broadcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - - testWrite(c, flow) - }) - } -} - -// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a -// socket that is bound to the V4-mapped broadcast address. -func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { - t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Bind to V4Mapped mcast address. - if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - testWrite(c, flow) - }) - } -} - -func TestReadIncrementsPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - // Create IPv4 UDP endpoint - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - testRead(c, unicastV4) - - var want uint64 = 1 - if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { - c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want) - } -} - -func TestWriteIncrementsPacketsSent(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - testDualWrite(c) - - var want uint64 = 2 - if got := c.s.Stats().UDP.PacketsSent.Value(); got != want { - c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) - } -} - -func TestTTL(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - var wantTTL uint8 - if flow.isMulticast() { - wantTTL = multicastTTL - } else { - var p stack.NetworkProtocol - if flow.isV4() { - p = ipv4.NewProtocol() - } else { - p = ipv6.NewProtocol() - } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - })) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - } - - testWrite(c, flow, checker.TTL(wantTTL)) - }) - } -} - -func TestSetTTL(t *testing.T) { - for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { - t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - var p stack.NetworkProtocol - if flow.isV4() { - p = ipv4.NewProtocol() - } else { - p = ipv6.NewProtocol() - } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - })) - if err != nil { - t.Fatal(err) - } - ep.Close() - - testWrite(c, flow, checker.TTL(wantTTL)) - }) - } - }) - } -} - -func TestTOSV4(t *testing.T) { - for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const tos = testTOS - var v tcpip.IPv4TOSOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) - } - // Test for expected default value. - if v != 0 { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) - } - - if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) - } - - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) - } - - if want := tcpip.IPv4TOSOption(tos); v != want { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) - } - - testWrite(c, flow, checker.TOS(tos, 0)) - }) - } -} - -func TestTOSV6(t *testing.T) { - for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - const tos = testTOS - var v tcpip.IPv6TrafficClassOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) - } - // Test for expected default value. - if v != 0 { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) - } - - if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - c.t.Errorf("SetSockOpt failed: %s", err) - } - - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) - } - - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) - } - - testWrite(c, flow, checker.TOS(tos, 0)) - }) - } -} - -func TestReceiveTOSV4(t *testing.T) { - for _, flow := range []testFlow{unicastV4, broadcast} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Verify that setting and reading the option works. - v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) - } - // Test for expected default value. - if v != false { - c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false) - } - - want := true - if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil { - c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err) - } - - got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) - } - if got != want { - c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want) - } - - // Verify that the correct received TOS is handed through as - // ancillary data to the ControlMessages struct. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - testRead(c, flow, checker.ReceiveTOS(testTOS)) - }) - } -} - -func TestMulticastInterfaceOption(t *testing.T) { - for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - for _, bindTyp := range []string{"bound", "unbound"} { - t.Run(bindTyp, func(t *testing.T) { - for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { - t.Run(optTyp, func(t *testing.T) { - h := flow.header4Tuple(outgoing) - mcastAddr := h.dstAddr.Addr - localIfAddr := h.srcAddr.Addr - - var ifoptSet tcpip.MulticastInterfaceOption - switch optTyp { - case "use local-addr": - ifoptSet.InterfaceAddr = localIfAddr - case "use NICID": - ifoptSet.NIC = 1 - case "use local-addr and NIC": - ifoptSet.InterfaceAddr = localIfAddr - ifoptSet.NIC = 1 - default: - t.Fatal("unknown test variant") - } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(flow.sockProto()) - - if bindTyp == "bound" { - // Bind the socket by connecting to the multicast address. - // This may have an influence on how the multicast interface - // is set. - addr := tcpip.FullAddress{ - Addr: flow.mapAddrIfApplicable(mcastAddr), - Port: stackPort, - } - if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - } - - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - // Verify multicast interface addr and NIC were set correctly. - // Note that NIC must be 1 since this is our outgoing interface. - ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} - var ifoptGot tcpip.MulticastInterfaceOption - if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %v", err) - } - if ifoptGot != ifoptWant { - c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) - } - }) - } - }) - } - }) - } -} - -// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination -// Unreachable message when a udp datagram is received on ports for which there -// is no bound udp socket. -func TestV4UnknownDestination(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - testCases := []struct { - flow testFlow - icmpRequired bool - // largePayload if true, will result in a payload large enough - // so that the final generated IPv4 packet is larger than - // header.IPv4MinimumProcessableDatagramSize. - largePayload bool - }{ - {unicastV4, true, false}, - {unicastV4, true, true}, - {multicastV4, false, false}, - {multicastV4, false, true}, - {broadcast, false, false}, - {broadcast, false, true}, - } - for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { - payload := newPayload() - if tc.largePayload { - payload = newMinPayload(576) - } - c.injectPacket(tc.flow, payload) - if !tc.icmpRequired { - ctx, _ := context.WithTimeout(context.Background(), time.Second) - if p, ok := c.linkEP.ReadContext(ctx); ok { - t.Fatalf("unexpected packet received: %+v", p) - } - return - } - - // ICMP required. - ctx, _ := context.WithTimeout(context.Background(), time.Second) - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - t.Fatalf("packet wasn't written out") - return - } - - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } - - hdr := header.IPv4(pkt) - checker.IPv4(t, hdr, checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - - icmpPkt := header.ICMPv4(hdr.Payload()) - payloadIPHeader := header.IPv4(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize - } - - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) - - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) - } - }) - } -} - -// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination -// Unreachable message when a udp datagram is received on ports for which there -// is no bound udp socket. -func TestV6UnknownDestination(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - testCases := []struct { - flow testFlow - icmpRequired bool - // largePayload if true will result in a payload large enough to - // create an IPv6 packet > header.IPv6MinimumMTU bytes. - largePayload bool - }{ - {unicastV6, true, false}, - {unicastV6, true, true}, - {multicastV6, false, false}, - {multicastV6, false, true}, - } - for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { - payload := newPayload() - if tc.largePayload { - payload = newMinPayload(1280) - } - c.injectPacket(tc.flow, payload) - if !tc.icmpRequired { - ctx, _ := context.WithTimeout(context.Background(), time.Second) - if p, ok := c.linkEP.ReadContext(ctx); ok { - t.Fatalf("unexpected packet received: %+v", p) - } - return - } - - // ICMP required. - ctx, _ := context.WithTimeout(context.Background(), time.Second) - p, ok := c.linkEP.ReadContext(ctx) - if !ok { - t.Fatalf("packet wasn't written out") - return - } - - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv6MinimumMTU; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } - - hdr := header.IPv6(pkt) - checker.IPv6(t, hdr, checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - - icmpPkt := header.ICMPv6(hdr.Payload()) - payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize - } - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) - - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) - } - }) - } -} - -// TestIncrementMalformedPacketsReceived verifies if the malformed received -// global and endpoint stats get incremented. -func TestIncrementMalformedPacketsReceived(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - payload := newPayload() - c.t.Helper() - h := unicastV6.header4Tuple(incoming) - c.injectV6Packet(payload, &h, false /* !valid */) - - var want uint64 = 1 - if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) - } -} - -// TestShutdownRead verifies endpoint read shutdown and error -// stats increment on packet receive. -func TestShutdownRead(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) - } - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - testFailingRead(c, unicastV6, true /* expectReadError */) - - var want uint64 = 1 - if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { - t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) - } -} - -// TestShutdownWrite verifies endpoint write shutdown and error -// stats increment on packet write. -func TestShutdownWrite(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) - } - - if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) - } - - testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) -} - -func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { - got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { - case nil: - want.PacketsSent.IncrementBy(incr) - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: - want.WriteErrors.InvalidArgs.IncrementBy(incr) - case tcpip.ErrClosedForSend: - want.WriteErrors.WriteClosed.IncrementBy(incr) - case tcpip.ErrInvalidEndpointState: - want.WriteErrors.InvalidEndpointState.IncrementBy(incr) - case tcpip.ErrNoLinkAddress: - want.SendErrors.NoLinkAddr.IncrementBy(incr) - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: - want.SendErrors.NoRoute.IncrementBy(incr) - default: - want.SendErrors.SendToNetworkFailed.IncrementBy(incr) - } - if got != want { - c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) - } -} - -func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { - got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { - case nil, tcpip.ErrWouldBlock: - case tcpip.ErrClosedForReceive: - want.ReadErrors.ReadClosed.IncrementBy(incr) - default: - c.t.Errorf("Endpoint error missing stats update err %v", err) - } - if got != want { - c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) - } -} |