summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/BUILD63
-rw-r--r--pkg/tcpip/transport/udp/udp_packet_list.go221
-rw-r--r--pkg/tcpip/transport/udp/udp_state_autogen.go229
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go2565
4 files changed, 450 insertions, 2628 deletions
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
deleted file mode 100644
index 153e8c950..000000000
--- a/pkg/tcpip/transport/udp/BUILD
+++ /dev/null
@@ -1,63 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "udp_packet_list",
- out = "udp_packet_list.go",
- package = "udp",
- prefix = "udpPacket",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*udpPacket",
- "Linker": "*udpPacket",
- },
-)
-
-go_library(
- name = "udp",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "forwarder.go",
- "protocol.go",
- "udp_packet_list.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/raw",
- "//pkg/waiter",
- ],
-)
-
-go_test(
- name = "udp_x_test",
- size = "small",
- srcs = ["udp_test.go"],
- deps = [
- ":udp",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/icmp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go
new file mode 100644
index 000000000..c396f77c9
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_packet_list.go
@@ -0,0 +1,221 @@
+package udp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type udpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type udpPacketList struct {
+ head *udpPacket
+ tail *udpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *udpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *udpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *udpPacketList) Front() *udpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *udpPacketList) Back() *udpPacket {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *udpPacketList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (udpPacketElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *udpPacketList) PushFront(e *udpPacket) {
+ linker := udpPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *udpPacketList) PushBack(e *udpPacket) {
+ linker := udpPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *udpPacketList) PushBackList(m *udpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *udpPacketList) InsertAfter(b, e *udpPacket) {
+ bLinker := udpPacketElementMapper{}.linkerFor(b)
+ eLinker := udpPacketElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *udpPacketList) InsertBefore(a, e *udpPacket) {
+ aLinker := udpPacketElementMapper{}.linkerFor(a)
+ eLinker := udpPacketElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ udpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *udpPacketList) Remove(e *udpPacket) {
+ linker := udpPacketElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ udpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ udpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type udpPacketEntry struct {
+ next *udpPacket
+ prev *udpPacket
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *udpPacketEntry) Next() *udpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *udpPacketEntry) Prev() *udpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *udpPacketEntry) SetNext(elem *udpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *udpPacketEntry) SetPrev(elem *udpPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
new file mode 100644
index 000000000..16900d0f9
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -0,0 +1,229 @@
+// automatically generated by stateify.
+
+package udp
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (u *udpPacket) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.udpPacket"
+}
+
+func (u *udpPacket) StateFields() []string {
+ return []string{
+ "udpPacketEntry",
+ "senderAddress",
+ "destinationAddress",
+ "packetInfo",
+ "data",
+ "timestamp",
+ "tos",
+ }
+}
+
+func (u *udpPacket) beforeSave() {}
+
+func (u *udpPacket) StateSave(stateSinkObject state.Sink) {
+ u.beforeSave()
+ var dataValue buffer.VectorisedView = u.saveData()
+ stateSinkObject.SaveValue(4, dataValue)
+ stateSinkObject.Save(0, &u.udpPacketEntry)
+ stateSinkObject.Save(1, &u.senderAddress)
+ stateSinkObject.Save(2, &u.destinationAddress)
+ stateSinkObject.Save(3, &u.packetInfo)
+ stateSinkObject.Save(5, &u.timestamp)
+ stateSinkObject.Save(6, &u.tos)
+}
+
+func (u *udpPacket) afterLoad() {}
+
+func (u *udpPacket) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &u.udpPacketEntry)
+ stateSourceObject.Load(1, &u.senderAddress)
+ stateSourceObject.Load(2, &u.destinationAddress)
+ stateSourceObject.Load(3, &u.packetInfo)
+ stateSourceObject.Load(5, &u.timestamp)
+ stateSourceObject.Load(6, &u.tos)
+ stateSourceObject.LoadValue(4, new(buffer.VectorisedView), func(y interface{}) { u.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (e *endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.endpoint"
+}
+
+func (e *endpoint) StateFields() []string {
+ return []string{
+ "TransportEndpointInfo",
+ "DefaultSocketOptionsHandler",
+ "waiterQueue",
+ "uniqueID",
+ "rcvReady",
+ "rcvList",
+ "rcvBufSizeMax",
+ "rcvBufSize",
+ "rcvClosed",
+ "state",
+ "dstPort",
+ "ttl",
+ "multicastTTL",
+ "multicastAddr",
+ "multicastNICID",
+ "portFlags",
+ "lastError",
+ "boundBindToDevice",
+ "boundPortFlags",
+ "sendTOS",
+ "shutdownFlags",
+ "multicastMemberships",
+ "effectiveNetProtos",
+ "owner",
+ "ops",
+ }
+}
+
+func (e *endpoint) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax()
+ stateSinkObject.SaveValue(6, rcvBufSizeMaxValue)
+ stateSinkObject.Save(0, &e.TransportEndpointInfo)
+ stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler)
+ stateSinkObject.Save(2, &e.waiterQueue)
+ stateSinkObject.Save(3, &e.uniqueID)
+ stateSinkObject.Save(4, &e.rcvReady)
+ stateSinkObject.Save(5, &e.rcvList)
+ stateSinkObject.Save(7, &e.rcvBufSize)
+ stateSinkObject.Save(8, &e.rcvClosed)
+ stateSinkObject.Save(9, &e.state)
+ stateSinkObject.Save(10, &e.dstPort)
+ stateSinkObject.Save(11, &e.ttl)
+ stateSinkObject.Save(12, &e.multicastTTL)
+ stateSinkObject.Save(13, &e.multicastAddr)
+ stateSinkObject.Save(14, &e.multicastNICID)
+ stateSinkObject.Save(15, &e.portFlags)
+ stateSinkObject.Save(16, &e.lastError)
+ stateSinkObject.Save(17, &e.boundBindToDevice)
+ stateSinkObject.Save(18, &e.boundPortFlags)
+ stateSinkObject.Save(19, &e.sendTOS)
+ stateSinkObject.Save(20, &e.shutdownFlags)
+ stateSinkObject.Save(21, &e.multicastMemberships)
+ stateSinkObject.Save(22, &e.effectiveNetProtos)
+ stateSinkObject.Save(23, &e.owner)
+ stateSinkObject.Save(24, &e.ops)
+}
+
+func (e *endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.TransportEndpointInfo)
+ stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler)
+ stateSourceObject.Load(2, &e.waiterQueue)
+ stateSourceObject.Load(3, &e.uniqueID)
+ stateSourceObject.Load(4, &e.rcvReady)
+ stateSourceObject.Load(5, &e.rcvList)
+ stateSourceObject.Load(7, &e.rcvBufSize)
+ stateSourceObject.Load(8, &e.rcvClosed)
+ stateSourceObject.Load(9, &e.state)
+ stateSourceObject.Load(10, &e.dstPort)
+ stateSourceObject.Load(11, &e.ttl)
+ stateSourceObject.Load(12, &e.multicastTTL)
+ stateSourceObject.Load(13, &e.multicastAddr)
+ stateSourceObject.Load(14, &e.multicastNICID)
+ stateSourceObject.Load(15, &e.portFlags)
+ stateSourceObject.Load(16, &e.lastError)
+ stateSourceObject.Load(17, &e.boundBindToDevice)
+ stateSourceObject.Load(18, &e.boundPortFlags)
+ stateSourceObject.Load(19, &e.sendTOS)
+ stateSourceObject.Load(20, &e.shutdownFlags)
+ stateSourceObject.Load(21, &e.multicastMemberships)
+ stateSourceObject.Load(22, &e.effectiveNetProtos)
+ stateSourceObject.Load(23, &e.owner)
+ stateSourceObject.Load(24, &e.ops)
+ stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) })
+ stateSourceObject.AfterLoad(e.afterLoad)
+}
+
+func (m *multicastMembership) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.multicastMembership"
+}
+
+func (m *multicastMembership) StateFields() []string {
+ return []string{
+ "nicID",
+ "multicastAddr",
+ }
+}
+
+func (m *multicastMembership) beforeSave() {}
+
+func (m *multicastMembership) StateSave(stateSinkObject state.Sink) {
+ m.beforeSave()
+ stateSinkObject.Save(0, &m.nicID)
+ stateSinkObject.Save(1, &m.multicastAddr)
+}
+
+func (m *multicastMembership) afterLoad() {}
+
+func (m *multicastMembership) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &m.nicID)
+ stateSourceObject.Load(1, &m.multicastAddr)
+}
+
+func (l *udpPacketList) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.udpPacketList"
+}
+
+func (l *udpPacketList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *udpPacketList) beforeSave() {}
+
+func (l *udpPacketList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *udpPacketList) afterLoad() {}
+
+func (l *udpPacketList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *udpPacketEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.udpPacketEntry"
+}
+
+func (e *udpPacketEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *udpPacketEntry) beforeSave() {}
+
+func (e *udpPacketEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *udpPacketEntry) afterLoad() {}
+
+func (e *udpPacketEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*udpPacket)(nil))
+ state.Register((*endpoint)(nil))
+ state.Register((*multicastMembership)(nil))
+ state.Register((*udpPacketList)(nil))
+ state.Register((*udpPacketEntry)(nil))
+}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
deleted file mode 100644
index c8126b51b..000000000
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ /dev/null
@@ -1,2565 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package udp_test
-
-import (
- "bytes"
- "context"
- "fmt"
- "io/ioutil"
- "math/rand"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// Addresses and ports used for testing. It is recommended that tests stick to
-// using these addresses as it allows using the testFlow helper.
-// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*'
-// represents the remote endpoint.
-const (
- v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
- stackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackV4MappedAddr = v4MappedAddrPrefix + stackAddr
- testV4MappedAddr = v4MappedAddrPrefix + testAddr
- multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
- broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr
- v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00"
-
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testAddr = "\x0a\x00\x00\x02"
- testPort = 4096
- invalidPort = 8192
- multicastAddr = "\xe8\x2b\xd3\xea"
- multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- broadcastAddr = header.IPv4Broadcast
- testTOS = 0x80
-
- // defaultMTU is the MTU, in bytes, used throughout the tests, except
- // where another value is explicitly used. It is chosen to match the MTU
- // of loopback interfaces on linux systems.
- defaultMTU = 65536
-)
-
-// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in
-// a packet header. These values are used to populate a header or verify one.
-// Note that because they are used in packet headers, the addresses are never in
-// a V4-mapped format.
-type header4Tuple struct {
- srcAddr tcpip.FullAddress
- dstAddr tcpip.FullAddress
-}
-
-// testFlow implements a helper type used for sending and receiving test
-// packets. A given test flow value defines 1) the socket endpoint used for the
-// test and 2) the type of packet send or received on the endpoint. E.g., a
-// multicastV6Only flow is a V6 multicast packet passing through a V6-only
-// endpoint. The type provides helper methods to characterize the flow (e.g.,
-// isV4) as well as return a proper header4Tuple for it.
-type testFlow int
-
-const (
- unicastV4 testFlow = iota // V4 unicast on a V4 socket
- unicastV4in6 // V4-mapped unicast on a V6-dual socket
- unicastV6 // V6 unicast on a V6 socket
- unicastV6Only // V6 unicast on a V6-only socket
- multicastV4 // V4 multicast on a V4 socket
- multicastV4in6 // V4-mapped multicast on a V6-dual socket
- multicastV6 // V6 multicast on a V6 socket
- multicastV6Only // V6 multicast on a V6-only socket
- broadcast // V4 broadcast on a V4 socket
- broadcastIn6 // V4-mapped broadcast on a V6-dual socket
- reverseMulticast4 // V4 multicast src. Must fail.
- reverseMulticast6 // V6 multicast src. Must fail.
-)
-
-func (flow testFlow) String() string {
- switch flow {
- case unicastV4:
- return "unicastV4"
- case unicastV6:
- return "unicastV6"
- case unicastV6Only:
- return "unicastV6Only"
- case unicastV4in6:
- return "unicastV4in6"
- case multicastV4:
- return "multicastV4"
- case multicastV6:
- return "multicastV6"
- case multicastV6Only:
- return "multicastV6Only"
- case multicastV4in6:
- return "multicastV4in6"
- case broadcast:
- return "broadcast"
- case broadcastIn6:
- return "broadcastIn6"
- case reverseMulticast4:
- return "reverseMulticast4"
- case reverseMulticast6:
- return "reverseMulticast6"
- default:
- return "unknown"
- }
-}
-
-// packetDirection explains if a flow is incoming (read) or outgoing (write).
-type packetDirection int
-
-const (
- incoming packetDirection = iota
- outgoing
-)
-
-// header4Tuple returns the header4Tuple for the given flow and direction. Note
-// that the tuple contains no mapped addresses as those only exist at the socket
-// level but not at the packet header level.
-func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
- var h header4Tuple
- if flow.isV4() {
- if d == outgoing {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
- dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
- }
- } else {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
- dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
- }
- }
- if flow.isMulticast() {
- h.dstAddr.Addr = multicastAddr
- } else if flow.isBroadcast() {
- h.dstAddr.Addr = broadcastAddr
- }
- } else { // IPv6
- if d == outgoing {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- }
- } else {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- }
- }
- if flow.isMulticast() {
- h.dstAddr.Addr = multicastV6Addr
- }
- }
- if flow.isReverseMulticast() {
- h.srcAddr.Addr = flow.getMcastAddr()
- }
- return h
-}
-
-func (flow testFlow) getMcastAddr() tcpip.Address {
- if flow.isV4() {
- return multicastAddr
- }
- return multicastV6Addr
-}
-
-// mapAddrIfApplicable converts the given V4 address into its V4-mapped version
-// if it is applicable to the flow.
-func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address {
- if flow.isMapped() {
- return v4MappedAddrPrefix + v4Addr
- }
- return v4Addr
-}
-
-// netProto returns the protocol number used for the network packet.
-func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
- if flow.isV4() {
- return ipv4.ProtocolNumber
- }
- return ipv6.ProtocolNumber
-}
-
-// sockProto returns the protocol number used when creating the socket
-// endpoint for this flow.
-func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
- switch flow {
- case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6:
- return ipv6.ProtocolNumber
- case unicastV4, multicastV4, broadcast, reverseMulticast4:
- return ipv4.ProtocolNumber
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) {
- if flow.isV4() {
- return checker.IPv4
- }
- return checker.IPv6
-}
-
-func (flow testFlow) isV6() bool { return !flow.isV4() }
-func (flow testFlow) isV4() bool {
- return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped()
-}
-
-func (flow testFlow) isV6Only() bool {
- switch flow {
- case unicastV6Only, multicastV6Only:
- return true
- case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isMulticast() bool {
- switch flow {
- case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
- return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isBroadcast() bool {
- switch flow {
- case broadcast, broadcastIn6:
- return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isMapped() bool {
- switch flow {
- case unicastV4in6, multicastV4in6, broadcastIn6:
- return true
- case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isReverseMulticast() bool {
- switch flow {
- case reverseMulticast4, reverseMulticast6:
- return true
- default:
- return false
- }
-}
-
-type testContext struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
-
- ep tcpip.Endpoint
- wq waiter.Queue
-}
-
-func newDualTestContext(t *testing.T, mtu uint32) *testContext {
- t.Helper()
- return newDualTestContextWithOptions(t, mtu, stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
- HandleLocal: true,
- })
-}
-
-func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext {
- t.Helper()
-
- s := stack.New(options)
- ep := channel.New(256, mtu, "")
- wep := stack.LinkEndpoint(ep)
-
- if testing.Verbose() {
- wep = sniffer.New(ep)
- }
- if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %s", err)
- }
-
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %s", err)
- }
-
- if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %s", err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
- })
-
- return &testContext{
- t: t,
- s: s,
- linkEP: ep,
- }
-}
-
-func (c *testContext) cleanup() {
- if c.ep != nil {
- c.ep.Close()
- }
-}
-
-func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
- c.t.Helper()
-
- var err tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
- if err != nil {
- c.t.Fatal("NewEndpoint failed: ", err)
- }
-}
-
-func (c *testContext) createEndpointForFlow(flow testFlow) {
- c.t.Helper()
-
- c.createEndpoint(flow.sockProto())
- if flow.isV6Only() {
- c.ep.SocketOptions().SetV6Only(true)
- } else if flow.isBroadcast() {
- c.ep.SocketOptions().SetBroadcast(true)
- }
-}
-
-// getPacketAndVerify reads a packet from the link endpoint and verifies the
-// header against expected values from the given test flow. In addition, it
-// calls any extra checker functions provided.
-func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
- c.t.Helper()
-
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
- if !ok {
- c.t.Fatalf("Packet wasn't written out")
- return nil
- }
-
- if p.Proto != flow.netProto() {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
- }
-
- if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want {
- c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want)
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- h := flow.header4Tuple(outgoing)
- checkers = append(
- checkers,
- checker.SrcAddr(h.srcAddr.Addr),
- checker.DstAddr(h.dstAddr.Addr),
- checker.UDP(checker.DstPort(h.dstAddr.Port)),
- )
- flow.checkerFn()(c.t, b, checkers...)
- return b
-}
-
-// injectPacket creates a packet of the given flow and with the given payload,
-// and injects it into the link endpoint. If badChecksum is true, the packet has
-// a bad checksum in the UDP header.
-func (c *testContext) injectPacket(flow testFlow, payload []byte, badChecksum bool) {
- c.t.Helper()
-
- h := flow.header4Tuple(incoming)
- if flow.isV4() {
- buf := c.buildV4Packet(payload, &h)
- if badChecksum {
- // Invalidate the UDP header checksum field, taking care to avoid
- // overflow to zero, which would disable checksum validation.
- for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
- u.SetChecksum(u.Checksum() + 1)
- if u.Checksum() != 0 {
- break
- }
- }
- }
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- } else {
- buf := c.buildV6Packet(payload, &h)
- if badChecksum {
- // Invalidate the UDP header checksum field (Unlike IPv4, zero is
- // a valid checksum value for IPv6 so no need to avoid it).
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.SetChecksum(u.Checksum() + 1)
- }
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-}
-
-// buildV6Packet creates a V6 test packet with the given payload and header
-// values in a buffer.
-func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
- payloadStart := len(buf) - len(payload)
- copy(buf[payloadStart:], payload)
-
- // Initialize the IP header.
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
- })
-
- // Initialize the UDP header.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.Encode(&header.UDPFields{
- SrcPort: h.srcAddr.Port,
- DstPort: h.dstAddr.Port,
- Length: uint16(header.UDPMinimumSize + len(payload)),
- })
-
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
-
- // Calculate the UDP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum))
-
- return buf
-}
-
-// buildV4Packet creates a V4 test packet with the given payload and header
-// values in a buffer.
-func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
- payloadStart := len(buf) - len(payload)
- copy(buf[payloadStart:], payload)
-
- // Initialize the IP header.
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- TOS: testTOS,
- TotalLength: uint16(len(buf)),
- TTL: 65,
- Protocol: uint8(udp.ProtocolNumber),
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Initialize the UDP header.
- u := header.UDP(buf[header.IPv4MinimumSize:])
- u.Encode(&header.UDPFields{
- SrcPort: h.srcAddr.Port,
- DstPort: h.dstAddr.Port,
- Length: uint16(header.UDPMinimumSize + len(payload)),
- })
-
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
-
- // Calculate the UDP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum))
-
- return buf
-}
-
-func newPayload() []byte {
- return newMinPayload(30)
-}
-
-func newMinPayload(minSize int) []byte {
- b := make([]byte, minSize+rand.Intn(100))
- for i := range b {
- b[i] = byte(rand.Intn(256))
- }
- return b
-}
-
-func TestBindToDeviceOption(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}})
-
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- defer ep.Close()
-
- opts := stack.NICOptions{Name: "my_device"}
- if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
- t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err)
- }
-
- // nicIDPtr is used instead of taking the address of NICID literals, which is
- // a compiler error.
- nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
- return &s
- }
-
- testActions := []struct {
- name string
- setBindToDevice *tcpip.NICID
- setBindToDeviceError tcpip.Error
- getBindToDevice int32
- }{
- {"GetDefaultValue", nil, nil, 0},
- {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0},
- {"BindToExistent", nicIDPtr(321), nil, 321},
- {"UnbindToDevice", nicIDPtr(0), nil, 0},
- }
- for _, testAction := range testActions {
- t.Run(testAction.name, func(t *testing.T) {
- if testAction.setBindToDevice != nil {
- bindToDevice := int32(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
- }
- }
- bindToDevice := ep.SocketOptions().GetBindToDevice()
- if bindToDevice != testAction.getBindToDevice {
- t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice)
- }
- })
- }
-}
-
-// testReadInternal sends a packet of the given test flow into the stack by
-// injecting it into the link endpoint. It then attempts to read it from the
-// UDP endpoint and depending on if this was expected to succeed verifies its
-// correctness including any additional checker functions provided.
-func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
- c.t.Helper()
-
- payload := newPayload()
- c.injectPacket(flow, payload, false)
-
- // Try to receive the data.
- we, ch := waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&we, waiter.EventIn)
- defer c.wq.EventUnregister(&we)
-
- // Take a snapshot of the stats to validate them at the end of the test.
- epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
-
- var buf bytes.Buffer
- res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- // Wait for data to become available.
- select {
- case <-ch:
- res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
-
- case <-time.After(300 * time.Millisecond):
- if packetShouldBeDropped {
- return // expected to time out
- }
- c.t.Fatal("timed out waiting for data")
- }
- }
-
- if expectReadError && err != nil {
- c.checkEndpointReadStats(1, epstats, err)
- return
- }
-
- if err != nil {
- c.t.Fatal("Read failed:", err)
- }
-
- if packetShouldBeDropped {
- c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr)
- }
-
- // Check the read result.
- h := flow.header4Tuple(incoming)
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr},
- }, res, checker.IgnoreCmpPath(
- "ControlMessages", // ControlMessages will be checked later.
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff)
- }
-
- // Check the payload.
- v := buf.Bytes()
- if !bytes.Equal(payload, v) {
- c.t.Fatalf("got payload = %x, want = %x", v, payload)
- }
-
- // Run any checkers against the ControlMessages.
- for _, f := range checkers {
- f(c.t, res.ControlMessages)
- }
-
- c.checkEndpointReadStats(1, epstats, err)
-}
-
-// testRead sends a packet of the given test flow into the stack by injecting it
-// into the link endpoint. It then reads it from the UDP endpoint and verifies
-// its correctness including any additional checker functions provided.
-func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
- c.t.Helper()
- testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
-}
-
-// testFailingRead sends a packet of the given test flow into the stack by
-// injecting it into the link endpoint. It then tries to read it from the UDP
-// endpoint and expects this to fail.
-func testFailingRead(c *testContext, flow testFlow, expectReadError bool) {
- c.t.Helper()
- testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError)
-}
-
-func TestBindEphemeralPort(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %s", err)
- }
-}
-
-func TestBindReservedPort(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- addr, err := c.ep.GetLocalAddress()
- if err != nil {
- t.Fatalf("GetLocalAddress failed: %s", err)
- }
-
- // We can't bind the address reserved by the connected endpoint above.
- {
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
- {
- err := ep.Bind(addr)
- if _, ok := err.(*tcpip.ErrPortInUse); !ok {
- t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
- }
- }
- }
-
- func() {
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
- // We can't bind ipv4-any on the port reserved by the connected endpoint
- // above, since the endpoint is dual-stack.
- {
- err := ep.Bind(tcpip.FullAddress{Port: addr.Port})
- if _, ok := err.(*tcpip.ErrPortInUse); !ok {
- t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
- }
- }
- // We can bind an ipv4 address on this port, though.
- if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %s", err)
- }
- }()
-
- // Once the connected endpoint releases its port reservation, we are able to
- // bind ipv4-any once again.
- c.ep.Close()
- func() {
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
- if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %s", err)
- }
- }()
-}
-
-func TestV4ReadOnV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4in6)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4in6)
-}
-
-func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4in6)
-
- // Bind to v4 mapped wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4in6)
-}
-
-func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4in6)
-
- // Bind to local address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4in6)
-}
-
-func TestV6ReadOnV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV6)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV6)
-}
-
-// TestV4ReadSelfSource checks that packets coming from a local IP address are
-// correctly dropped when handleLocal is true and not otherwise.
-func TestV4ReadSelfSource(t *testing.T) {
- for _, tt := range []struct {
- name string
- handleLocal bool
- wantErr tcpip.Error
- wantInvalidSource uint64
- }{
- {"HandleLocal", false, nil, 0},
- {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1},
- } {
- t.Run(tt.name, func(t *testing.T) {
- c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- HandleLocal: tt.handleLocal,
- })
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4)
-
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV4.header4Tuple(incoming)
- h.srcAddr = h.dstAddr
-
- buf := c.buildV4Packet(payload, &h)
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
- t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
- }
-
- if _, err := c.ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tt.wantErr {
- t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr)
- }
- })
- }
-}
-
-func TestV4ReadOnV4(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4)
-}
-
-// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
-// address and receive data sent to that address.
-func TestReadOnBoundToMulticast(t *testing.T) {
- // FIXME(b/128189410): multicastV4in6 currently doesn't work as
- // AddMembershipOption doesn't handle V4in6 addresses.
- for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to multicast address.
- mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr())
- if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
-
- // Join multicast group.
- ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
- if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
- }
-
- // Check that we receive multicast packets but not unicast or broadcast
- // ones.
- testRead(c, flow)
- testFailingRead(c, broadcast, false /* expectReadError */)
- testFailingRead(c, unicastV4, false /* expectReadError */)
- })
- }
-}
-
-// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
-// address and can receive only broadcast data.
-func TestV4ReadOnBoundToBroadcast(t *testing.T) {
- for _, flow := range []testFlow{broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to broadcast address.
- bcastAddr := flow.mapAddrIfApplicable(broadcastAddr)
- if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Check that we receive broadcast packets but not unicast ones.
- testRead(c, flow)
- testFailingRead(c, unicastV4, false /* expectReadError */)
- })
- }
-}
-
-// TestReadFromMulticast checks that an endpoint will NOT receive a packet
-// that was sent with multicast SOURCE address.
-func TestReadFromMulticast(t *testing.T) {
- for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- testFailingRead(c, flow, false /* expectReadError */)
- })
- }
-}
-
-// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
-// and receive broadcast and unicast data.
-func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
- for _, flow := range []testFlow{broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s (", err)
- }
-
- // Check that we receive both broadcast and unicast packets.
- testRead(c, flow)
- testRead(c, unicastV4)
- })
- }
-}
-
-// testFailingWrite sends a packet of the given test flow into the UDP endpoint
-// and verifies it fails with the provided error code.
-func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) {
- c.t.Helper()
- // Take a snapshot of the stats to validate them at the end of the test.
- epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- h := flow.header4Tuple(outgoing)
- writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
-
- var r bytes.Reader
- r.Reset(newPayload())
- _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
- })
- c.checkEndpointWriteStats(1, epstats, gotErr)
- if gotErr != wantErr {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
- }
-}
-
-// testWrite sends a packet of the given test flow from the UDP endpoint to the
-// flow's destination address:port. It then receives it from the link endpoint
-// and verifies its correctness including any additional checker functions
-// provided.
-func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
- c.t.Helper()
- return testWriteAndVerifyInternal(c, flow, true, checkers...)
-}
-
-// testWriteWithoutDestination sends a packet of the given test flow from the
-// UDP endpoint without giving a destination address:port. It then receives it
-// from the link endpoint and verifies its correctness including any additional
-// checker functions provided.
-func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
- c.t.Helper()
- return testWriteAndVerifyInternal(c, flow, false, checkers...)
-}
-
-func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View {
- c.t.Helper()
- // Take a snapshot of the stats to validate them at the end of the test.
- epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
-
- writeOpts := tcpip.WriteOptions{}
- if setDest {
- h := flow.header4Tuple(outgoing)
- writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
- writeOpts = tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
- }
- }
- var r bytes.Reader
- payload := newPayload()
- r.Reset(payload)
- n, err := c.ep.Write(&r, writeOpts)
- if err != nil {
- c.t.Fatalf("Write failed: %s", err)
- }
- if n != int64(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
- }
- c.checkEndpointWriteStats(1, epstats, err)
- return payload
-}
-
-func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
- c.t.Helper()
- payload := testWriteNoVerify(c, flow, setDest)
- // Received the packet and check the payload.
- b := c.getPacketAndVerify(flow, checkers...)
- var udp header.UDP
- if flow.isV4() {
- udp = header.UDP(header.IPv4(b).Payload())
- } else {
- udp = header.UDP(header.IPv6(b).Payload())
- }
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
- }
-
- return udp.SourcePort()
-}
-
-func testDualWrite(c *testContext) uint16 {
- c.t.Helper()
-
- v4Port := testWrite(c, unicastV4in6)
- v6Port := testWrite(c, unicastV6)
- if v4Port != v6Port {
- c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
- }
-
- return v4Port
-}
-
-func TestDualWriteUnbound(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- testDualWrite(c)
-}
-
-func TestDualWriteBoundToWildcard(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- p := testDualWrite(c)
- if p != stackPort {
- c.t.Fatalf("Bad port: got %v, want %v", p, stackPort)
- }
-}
-
-func TestDualWriteConnectedToV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v6 address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, unicastV6)
-
- // Write to V4 mapped address.
- testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{})
- const want = 1
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
- c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
- }
-}
-
-func TestDualWriteConnectedToV4Mapped(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v4 mapped address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, unicastV4in6)
-
- // Write to v6 address.
- testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{})
-}
-
-func TestV4WriteOnV6Only(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV6Only)
-
- // Write to V4 mapped address.
- testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{})
-}
-
-func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to v4 mapped address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Write to v6 address.
- testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{})
-}
-
-func TestV6WriteOnConnected(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v6 address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- testWriteWithoutDestination(c, unicastV6)
-}
-
-func TestV4WriteOnConnected(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v4 mapped address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- testWriteWithoutDestination(c, unicastV4)
-}
-
-func TestWriteOnConnectedInvalidPort(t *testing.T) {
- protocols := map[string]tcpip.NetworkProtocolNumber{
- "ipv4": ipv4.ProtocolNumber,
- "ipv6": ipv6.ProtocolNumber,
- }
- for name, pn := range protocols {
- t.Run(name, func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(pn)
- if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
- writeOpts := tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
- }
- var r bytes.Reader
- payload := newPayload()
- r.Reset(payload)
- n, err := c.ep.Write(&r, writeOpts)
- if err != nil {
- c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err)
- }
- if got, want := n, int64(len(payload)); got != want {
- c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
- }
-
- {
- err := c.ep.LastError()
- if _, ok := err.(*tcpip.ErrConnectionRefused); !ok {
- c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
- }
- }
- })
- }
-}
-
-// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
-// that is bound to a V4 multicast address.
-func TestWriteOnBoundToV4Multicast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4 mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
-// socket that is bound to a V4-mapped multicast address.
-func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4Mapped mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
-// socket that is bound to a V6 multicast address.
-func TestWriteOnBoundToV6Multicast(t *testing.T) {
- for _, flow := range []testFlow{unicastV6, multicastV6} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V6 mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
-// V6-only socket that is bound to a V6 multicast address.
-func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
- for _, flow := range []testFlow{unicastV6Only, multicastV6Only} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V6 mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToBroadcast checks that we can send packets out of a
-// socket that is bound to the broadcast address.
-func TestWriteOnBoundToBroadcast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4 broadcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
-// socket that is bound to the V4-mapped broadcast address.
-func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4Mapped mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-func TestReadIncrementsPacketsReceived(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- // Create IPv4 UDP endpoint
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testRead(c, unicastV4)
-
- var want uint64 = 1
- if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
- c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want)
- }
-}
-
-func TestReadIPPacketInfo(t *testing.T) {
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- flow testFlow
- expectedLocalAddr tcpip.Address
- expectedDestAddr tcpip.Address
- }{
- {
- name: "IPv4 unicast",
- proto: header.IPv4ProtocolNumber,
- flow: unicastV4,
- expectedLocalAddr: stackAddr,
- expectedDestAddr: stackAddr,
- },
- {
- name: "IPv4 multicast",
- proto: header.IPv4ProtocolNumber,
- flow: multicastV4,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: multicastAddr,
- expectedDestAddr: multicastAddr,
- },
- {
- name: "IPv4 broadcast",
- proto: header.IPv4ProtocolNumber,
- flow: broadcast,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: broadcastAddr,
- expectedDestAddr: broadcastAddr,
- },
- {
- name: "IPv6 unicast",
- proto: header.IPv6ProtocolNumber,
- flow: unicastV6,
- expectedLocalAddr: stackV6Addr,
- expectedDestAddr: stackV6Addr,
- },
- {
- name: "IPv6 multicast",
- proto: header.IPv6ProtocolNumber,
- flow: multicastV6,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: multicastV6Addr,
- expectedDestAddr: multicastV6Addr,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(test.proto)
-
- bindAddr := tcpip.FullAddress{Port: stackPort}
- if err := c.ep.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%+v): %s", bindAddr, err)
- }
-
- if test.flow.isMulticast() {
- ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
- if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
- }
- }
-
- c.ep.SocketOptions().SetReceivePacketInfo(true)
-
- testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
- NIC: 1,
- LocalAddr: test.expectedLocalAddr,
- DestinationAddr: test.expectedDestAddr,
- }))
-
- if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
- t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestReadRecvOriginalDstAddr(t *testing.T) {
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- flow testFlow
- expectedOriginalDstAddr tcpip.FullAddress
- }{
- {
- name: "IPv4 unicast",
- proto: header.IPv4ProtocolNumber,
- flow: unicastV4,
- expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort},
- },
- {
- name: "IPv4 multicast",
- proto: header.IPv4ProtocolNumber,
- flow: multicastV4,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort},
- },
- {
- name: "IPv4 broadcast",
- proto: header.IPv4ProtocolNumber,
- flow: broadcast,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort},
- },
- {
- name: "IPv6 unicast",
- proto: header.IPv6ProtocolNumber,
- flow: unicastV6,
- expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort},
- },
- {
- name: "IPv6 multicast",
- proto: header.IPv6ProtocolNumber,
- flow: multicastV6,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(test.proto)
-
- bindAddr := tcpip.FullAddress{Port: stackPort}
- if err := c.ep.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%#v): %s", bindAddr, err)
- }
-
- if test.flow.isMulticast() {
- ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
- if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
- }
- }
-
- c.ep.SocketOptions().SetReceiveOriginalDstAddress(true)
-
- testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr))
-
- if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
- t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestWriteIncrementsPacketsSent(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- testDualWrite(c)
-
- var want uint64 = 2
- if got := c.s.Stats().UDP.PacketsSent.Value(); got != want {
- c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
- }
-}
-
-func TestNoChecksum(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Disable the checksum generation.
- c.ep.SocketOptions().SetNoChecksum(true)
- // This option is effective on IPv4 only.
- testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
-
- // Enable the checksum generation.
- c.ep.SocketOptions().SetNoChecksum(false)
- testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
- })
- }
-}
-
-var _ stack.NetworkInterface = (*testInterface)(nil)
-
-type testInterface struct {
- stack.NetworkInterface
-}
-
-func (*testInterface) ID() tcpip.NICID {
- return 0
-}
-
-func (*testInterface) Enabled() bool {
- return true
-}
-
-func TestTTL(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- const multicastTTL = 42
- if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil {
- c.t.Fatalf("SetSockOptInt failed: %s", err)
- }
-
- var wantTTL uint8
- if flow.isMulticast() {
- wantTTL = multicastTTL
- } else {
- var p stack.NetworkProtocolFactory
- var n tcpip.NetworkProtocolNumber
- if flow.isV4() {
- p = ipv4.NewProtocol
- n = ipv4.ProtocolNumber
- } else {
- p = ipv6.NewProtocol
- n = ipv6.ProtocolNumber
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{p},
- })
- ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil)
- wantTTL = ep.DefaultTTL()
- ep.Close()
- }
-
- testWrite(c, flow, checker.TTL(wantTTL))
- })
- }
-}
-
-func TestSetTTL(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
- t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
- c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
- }
-
- testWrite(c, flow, checker.TTL(wantTTL))
- })
- }
- })
- }
-}
-
-func TestSetTOS(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- const tos = testTOS
- v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
- if err != nil {
- c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
- }
- // Test for expected default value.
- if v != 0 {
- c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
- }
-
- if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
- c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err)
- }
-
- v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
- if err != nil {
- c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
- }
-
- if v != tos {
- c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos)
- }
-
- testWrite(c, flow, checker.TOS(tos, 0))
- })
- }
-}
-
-func TestSetTClass(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- const tClass = testTOS
- v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
- if err != nil {
- c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
- }
- // Test for expected default value.
- if v != 0 {
- c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0)
- }
-
- if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil {
- c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err)
- }
-
- v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
- if err != nil {
- c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
- }
-
- if v != tClass {
- c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass)
- }
-
- // The header getter for TClass is called TOS, so use that checker.
- testWrite(c, flow, checker.TOS(tClass, 0))
- })
- }
-}
-
-func TestReceiveTosTClass(t *testing.T) {
- const RcvTOSOpt = "ReceiveTosOption"
- const RcvTClassOpt = "ReceiveTClassOption"
-
- testCases := []struct {
- name string
- tests []testFlow
- }{
- {RcvTOSOpt, []testFlow{unicastV4, broadcast}},
- {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
- }
- for _, testCase := range testCases {
- for _, flow := range testCase.tests {
- t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
- name := testCase.name
-
- var optionGetter func() bool
- var optionSetter func(bool)
- switch name {
- case RcvTOSOpt:
- optionGetter = c.ep.SocketOptions().GetReceiveTOS
- optionSetter = c.ep.SocketOptions().SetReceiveTOS
- case RcvTClassOpt:
- optionGetter = c.ep.SocketOptions().GetReceiveTClass
- optionSetter = c.ep.SocketOptions().SetReceiveTClass
- default:
- t.Fatalf("unkown test variant: %s", name)
- }
-
- // Verify that setting and reading the option works.
- v := optionGetter()
- // Test for expected default value.
- if v != false {
- c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
- }
-
- want := true
- optionSetter(want)
-
- got := optionGetter()
- if got != want {
- c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
- }
-
- // Verify that the correct received TOS or TClass is handed through as
- // ancillary data to the ControlMessages struct.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
- switch name {
- case RcvTClassOpt:
- testRead(c, flow, checker.ReceiveTClass(testTOS))
- case RcvTOSOpt:
- testRead(c, flow, checker.ReceiveTOS(testTOS))
- default:
- t.Fatalf("unknown test variant: %s", name)
- }
- })
- }
- }
-}
-
-func TestMulticastInterfaceOption(t *testing.T) {
- for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- for _, bindTyp := range []string{"bound", "unbound"} {
- t.Run(bindTyp, func(t *testing.T) {
- for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
- t.Run(optTyp, func(t *testing.T) {
- h := flow.header4Tuple(outgoing)
- mcastAddr := h.dstAddr.Addr
- localIfAddr := h.srcAddr.Addr
-
- var ifoptSet tcpip.MulticastInterfaceOption
- switch optTyp {
- case "use local-addr":
- ifoptSet.InterfaceAddr = localIfAddr
- case "use NICID":
- ifoptSet.NIC = 1
- case "use local-addr and NIC":
- ifoptSet.InterfaceAddr = localIfAddr
- ifoptSet.NIC = 1
- default:
- t.Fatal("unknown test variant")
- }
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(flow.sockProto())
-
- if bindTyp == "bound" {
- // Bind the socket by connecting to the multicast address.
- // This may have an influence on how the multicast interface
- // is set.
- addr := tcpip.FullAddress{
- Addr: flow.mapAddrIfApplicable(mcastAddr),
- Port: stackPort,
- }
- if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
- }
-
- if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
- }
-
- // Verify multicast interface addr and NIC were set correctly.
- // Note that NIC must be 1 since this is our outgoing interface.
- var ifoptGot tcpip.MulticastInterfaceOption
- if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err)
- } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant {
- c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant)
- }
- })
- }
- })
- }
- })
- }
-}
-
-// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
-// Unreachable message when a udp datagram is received on ports for which there
-// is no bound udp socket.
-func TestV4UnknownDestination(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- testCases := []struct {
- flow testFlow
- icmpRequired bool
- // largePayload if true, will result in a payload large enough
- // so that the final generated IPv4 packet is larger than
- // header.IPv4MinimumProcessableDatagramSize.
- largePayload bool
- // badChecksum if true, will set an invalid checksum in the
- // header.
- badChecksum bool
- }{
- {unicastV4, true, false, false},
- {unicastV4, true, true, false},
- {unicastV4, false, false, true},
- {unicastV4, false, true, true},
- {multicastV4, false, false, false},
- {multicastV4, false, true, false},
- {broadcast, false, false, false},
- {broadcast, false, true, false},
- }
- checksumErrors := uint64(0)
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
- payload := newPayload()
- if tc.largePayload {
- payload = newMinPayload(576)
- }
- c.injectPacket(tc.flow, payload, tc.badChecksum)
- if tc.badChecksum {
- checksumErrors++
- if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
- t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- }
- if !tc.icmpRequired {
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- if p, ok := c.linkEP.ReadContext(ctx); ok {
- t.Fatalf("unexpected packet received: %+v", p)
- }
- return
- }
-
- // ICMP required.
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
- if !ok {
- t.Fatalf("packet wasn't written out")
- return
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- pkt := vv.ToView()
- if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
-
- hdr := header.IPv4(pkt)
- checker.IPv4(t, hdr, checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
-
- // We need to compare the included data part of the UDP packet that is in
- // the ICMP packet with the matching original data.
- icmpPkt := header.ICMPv4(hdr.Payload())
- payloadIPHeader := header.IPv4(icmpPkt.Payload())
- incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
- wantLen := len(payload)
- if tc.largePayload {
- // To work out the data size we need to simulate what the sender would
- // have done. The wanted size is the total available minus the sum of
- // the headers in the UDP AND ICMP packets, given that we know the test
- // had only a minimal IP header but the ICMP sender will have allowed
- // for a maximally sized packet header.
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
- }
-
- // In the case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- // Add back the two headers within the payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
-
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
- }
- })
- }
-}
-
-// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
-// Unreachable message when a udp datagram is received on ports for which there
-// is no bound udp socket.
-func TestV6UnknownDestination(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- testCases := []struct {
- flow testFlow
- icmpRequired bool
- // largePayload if true will result in a payload large enough to
- // create an IPv6 packet > header.IPv6MinimumMTU bytes.
- largePayload bool
- // badChecksum if true, will set an invalid checksum in the
- // header.
- badChecksum bool
- }{
- {unicastV6, true, false, false},
- {unicastV6, true, true, false},
- {unicastV6, false, false, true},
- {unicastV6, false, true, true},
- {multicastV6, false, false, false},
- {multicastV6, false, true, false},
- }
- checksumErrors := uint64(0)
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
- payload := newPayload()
- if tc.largePayload {
- payload = newMinPayload(1280)
- }
- c.injectPacket(tc.flow, payload, tc.badChecksum)
- if tc.badChecksum {
- checksumErrors++
- if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
- t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- }
- if !tc.icmpRequired {
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- if p, ok := c.linkEP.ReadContext(ctx); ok {
- t.Fatalf("unexpected packet received: %+v", p)
- }
- return
- }
-
- // ICMP required.
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
- if !ok {
- t.Fatalf("packet wasn't written out")
- return
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- pkt := vv.ToView()
- if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
-
- hdr := header.IPv6(pkt)
- checker.IPv6(t, hdr, checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
-
- icmpPkt := header.ICMPv6(hdr.Payload())
- payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
- }
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
-
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
- }
- })
- }
-}
-
-// TestIncrementMalformedPacketsReceived verifies if the malformed received
-// global and endpoint stats are incremented.
-func TestIncrementMalformedPacketsReceived(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV6.header4Tuple(incoming)
- buf := c.buildV6Packet(payload, &h)
-
- // Invalidate the UDP header length field.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.SetLength(u.Length() + 1)
-
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 1
- if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
- }
-}
-
-// TestShortHeader verifies that when a packet with a too-short UDP header is
-// received, the malformed received global stat gets incremented.
-func TestShortHeader(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- h := unicastV6.header4Tuple(incoming)
-
- // Allocate a buffer for an IPv6 and too-short UDP header.
- const udpSize = header.UDPMinimumSize - 1
- buf := buffer.NewView(header.IPv6MinimumSize + udpSize)
- // Initialize the IP header.
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(udpSize),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
- })
-
- // Initialize the UDP header.
- udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize))
- udpHdr.Encode(&header.UDPFields{
- SrcPort: h.srcAddr.Port,
- DstPort: h.dstAddr.Port,
- Length: header.UDPMinimumSize,
- })
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr)))
- udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
- // Copy all but the last byte of the UDP header into the packet.
- copy(buf[header.IPv6MinimumSize:], udpHdr)
-
- // Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want {
- t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want)
- }
-}
-
-// TestBadChecksumErrors verifies if a checksum error is detected,
-// global and endpoint stats are incremented.
-func TestBadChecksumErrors(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV6} {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(flow.sockProto())
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- c.injectPacket(flow, payload, true /* badChecksum */)
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
- }
-}
-
-// TestPayloadModifiedV4 verifies if a checksum error is detected,
-// global and endpoint stats are incremented.
-func TestPayloadModifiedV4(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv4.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV4.header4Tuple(incoming)
- buf := c.buildV4Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be
- // incorrect.
- buf[len(buf)-1]++
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestPayloadModifiedV6 verifies if a checksum error is detected,
-// global and endpoint stats are incremented.
-func TestPayloadModifiedV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV6.header4Tuple(incoming)
- buf := c.buildV6Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be
- // incorrect.
- buf[len(buf)-1]++
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestChecksumZeroV4 verifies if the checksum value is zero, global and
-// endpoint states are *not* incremented (UDP checksum is optional on IPv4).
-func TestChecksumZeroV4(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv4.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV4.header4Tuple(incoming)
- buf := c.buildV4Packet(payload, &h)
- // Set the checksum field in the UDP header to zero.
- u := header.UDP(buf[header.IPv4MinimumSize:])
- u.SetChecksum(0)
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 0
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestChecksumZeroV6 verifies if the checksum value is zero, global and
-// endpoint states are incremented (UDP checksum is *not* optional on IPv6).
-func TestChecksumZeroV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV6.header4Tuple(incoming)
- buf := c.buildV6Packet(payload, &h)
- // Set the checksum field in the UDP header to zero.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.SetChecksum(0)
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestShutdownRead verifies endpoint read shutdown and error
-// stats increment on packet receive.
-func TestShutdownRead(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- testFailingRead(c, unicastV6, true /* expectReadError */)
-
- var want uint64 = 1
- if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
- }
-}
-
-// TestShutdownWrite verifies endpoint write shutdown and error
-// stats increment on packet write.
-func TestShutdownWrite(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{})
-}
-
-func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) {
- got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- switch err.(type) {
- case nil:
- want.PacketsSent.IncrementBy(incr)
- case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
- want.WriteErrors.InvalidArgs.IncrementBy(incr)
- case *tcpip.ErrClosedForSend:
- want.WriteErrors.WriteClosed.IncrementBy(incr)
- case *tcpip.ErrInvalidEndpointState:
- want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
- case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
- want.SendErrors.NoRoute.IncrementBy(incr)
- default:
- want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
- }
- if got != want {
- c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
- }
-}
-
-func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) {
- got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- switch err.(type) {
- case nil, *tcpip.ErrWouldBlock:
- case *tcpip.ErrClosedForReceive:
- want.ReadErrors.ReadClosed.IncrementBy(incr)
- default:
- c.t.Errorf("Endpoint error missing stats update err %v", err)
- }
- if got != want {
- c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
- }
-}
-
-func TestOutgoingSubnetBroadcast(t *testing.T) {
- const nicID1 = 1
-
- ipv4Addr := tcpip.AddressWithPrefix{
- Address: "\xc0\xa8\x01\x3a",
- PrefixLen: 24,
- }
- ipv4Subnet := ipv4Addr.Subnet()
- ipv4SubnetBcast := ipv4Subnet.Broadcast()
- ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
- ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
- Address: "\xc0\xa8\x01\x3a",
- PrefixLen: 31,
- }
- ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
- ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
- ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
- Address: "\xc0\xa8\x01\x3a",
- PrefixLen: 32,
- }
- ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
- ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
- ipv6Addr := tcpip.AddressWithPrefix{
- Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- PrefixLen: 64,
- }
- ipv6Subnet := ipv6Addr.Subnet()
- ipv6SubnetBcast := ipv6Subnet.Broadcast()
- remNetAddr := tcpip.AddressWithPrefix{
- Address: "\x64\x0a\x7b\x18",
- PrefixLen: 24,
- }
- remNetSubnet := remNetAddr.Subnet()
- remNetSubnetBcast := remNetSubnet.Broadcast()
-
- tests := []struct {
- name string
- nicAddr tcpip.ProtocolAddress
- routes []tcpip.Route
- remoteAddr tcpip.Address
- requiresBroadcastOpt bool
- }{
- {
- name: "IPv4 Broadcast to local subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv4Subnet,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv4SubnetBcast,
- requiresBroadcastOpt: true,
- },
- {
- name: "IPv4 Broadcast to local /31 subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4AddrPrefix31,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv4Subnet31,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv4Subnet31Bcast,
- requiresBroadcastOpt: false,
- },
- {
- name: "IPv4 Broadcast to local /32 subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4AddrPrefix32,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv4Subnet32,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv4Subnet32Bcast,
- requiresBroadcastOpt: false,
- },
- // IPv6 has no notion of a broadcast.
- {
- name: "IPv6 'Broadcast' to local subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: ipv6Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv6Subnet,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv6SubnetBcast,
- requiresBroadcastOpt: false,
- },
- {
- name: "IPv4 Broadcast to remote subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: remNetSubnet,
- Gateway: ipv4Gateway,
- NIC: nicID1,
- },
- },
- remoteAddr: remNetSubnetBcast,
- // TODO(gvisor.dev/issue/3938): Once we support marking a route as
- // broadcast, this test should require the broadcast option to be set.
- requiresBroadcastOpt: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- e := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID1, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
- }
-
- s.SetRouteTable(test.routes)
-
- var netProto tcpip.NetworkProtocolNumber
- switch l := len(test.remoteAddr); l {
- case header.IPv4AddressSize:
- netProto = header.IPv4ProtocolNumber
- case header.IPv6AddressSize:
- netProto = header.IPv6ProtocolNumber
- default:
- t.Fatalf("got unexpected address length = %d bytes", l)
- }
-
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err)
- }
- defer ep.Close()
-
- var r bytes.Reader
- data := []byte{1, 2, 3, 4}
- to := tcpip.FullAddress{
- Addr: test.remoteAddr,
- Port: 80,
- }
- opts := tcpip.WriteOptions{To: &to}
- expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error {
- if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok {
- return nil
- }
- return &tcpip.ErrBroadcastDisabled{}
- }
- if !test.requiresBroadcastOpt {
- expectedErrWithoutBcastOpt = nil
- }
-
- r.Reset(data)
- {
- n, err := ep.Write(&r, opts)
- if expectedErrWithoutBcastOpt != nil {
- if want := expectedErrWithoutBcastOpt(err); want != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
- }
- } else if err != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
- }
- }
-
- ep.SocketOptions().SetBroadcast(true)
-
- r.Reset(data)
- if n, err := ep.Write(&r, opts); err != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
- }
-
- ep.SocketOptions().SetBroadcast(false)
-
- r.Reset(data)
- {
- n, err := ep.Write(&r, opts)
- if expectedErrWithoutBcastOpt != nil {
- if want := expectedErrWithoutBcastOpt(err); want != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
- }
- } else if err != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
- }
- }
- })
- }
-}