summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD6
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD6
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go22
-rw-r--r--pkg/tcpip/buffer/BUILD6
-rw-r--r--pkg/tcpip/checker/BUILD3
-rw-r--r--pkg/tcpip/checker/checker.go50
-rw-r--r--pkg/tcpip/hash/jenkins/BUILD6
-rw-r--r--pkg/tcpip/header/BUILD6
-rw-r--r--pkg/tcpip/header/checksum.go133
-rw-r--r--pkg/tcpip/header/checksum_test.go62
-rw-r--r--pkg/tcpip/header/icmpv6.go2
-rw-r--r--pkg/tcpip/header/ndp_options.go81
-rw-r--r--pkg/tcpip/header/ndp_test.go180
-rw-r--r--pkg/tcpip/iptables/BUILD4
-rw-r--r--pkg/tcpip/iptables/iptables.go11
-rw-r--r--pkg/tcpip/iptables/types.go15
-rw-r--r--pkg/tcpip/link/channel/BUILD3
-rw-r--r--pkg/tcpip/link/channel/channel.go43
-rw-r--r--pkg/tcpip/link/fdbased/BUILD6
-rw-r--r--pkg/tcpip/link/loopback/BUILD3
-rw-r--r--pkg/tcpip/link/muxed/BUILD6
-rw-r--r--pkg/tcpip/link/rawfile/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD6
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD6
-rw-r--r--pkg/tcpip/link/sharedmem/queue/BUILD6
-rw-r--r--pkg/tcpip/link/sniffer/BUILD3
-rw-r--r--pkg/tcpip/link/tun/BUILD3
-rw-r--r--pkg/tcpip/link/waitable/BUILD6
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/BUILD4
-rw-r--r--pkg/tcpip/network/arp/arp_test.go16
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD6
-rw-r--r--pkg/tcpip/network/hash/BUILD3
-rw-r--r--pkg/tcpip/network/ipv4/BUILD4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go3
-rw-r--r--pkg/tcpip/network/ipv6/BUILD6
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go41
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go7
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go135
-rw-r--r--pkg/tcpip/ports/BUILD6
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD2
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go2
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/BUILD2
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go2
-rw-r--r--pkg/tcpip/seqnum/BUILD3
-rw-r--r--pkg/tcpip/stack/BUILD6
-rw-r--r--pkg/tcpip/stack/ndp.go189
-rw-r--r--pkg/tcpip/stack/ndp_test.go195
-rw-r--r--pkg/tcpip/stack/nic.go239
-rw-r--r--pkg/tcpip/stack/stack_test.go4
-rw-r--r--pkg/tcpip/stack/transport_test.go6
-rw-r--r--pkg/tcpip/tcpip.go6
-rw-r--r--pkg/tcpip/transport/icmp/BUILD3
-rw-r--r--pkg/tcpip/transport/packet/BUILD3
-rw-r--r--pkg/tcpip/transport/raw/BUILD3
-rw-r--r--pkg/tcpip/transport/tcp/BUILD6
-rw-r--r--pkg/tcpip/transport/tcp/accept.go25
-rw-r--r--pkg/tcpip/transport/tcp/connect.go53
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go36
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go181
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/BUILD3
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go86
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD4
-rw-r--r--pkg/tcpip/transport/udp/BUILD4
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go182
66 files changed, 1527 insertions, 642 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 23e4b09e7..26f7ba86b 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -12,7 +11,6 @@ go_library(
"time_unsafe.go",
"timer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip",
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
@@ -25,7 +23,7 @@ go_test(
name = "tcpip_test",
size = "small",
srcs = ["tcpip_test.go"],
- embed = [":tcpip"],
+ library = ":tcpip",
)
go_test(
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index 3df7d18d3..a984f1712 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "gonet",
srcs = ["gonet.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet",
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
@@ -23,7 +21,7 @@ go_test(
name = "gonet_test",
size = "small",
srcs = ["gonet_test.go"],
- embed = [":gonet"],
+ library = ":gonet",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index a2f44b496..711969b9b 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -556,6 +556,17 @@ type PacketConn struct {
wq *waiter.Queue
}
+// NewPacketConn creates a new PacketConn.
+func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketConn {
+ c := &PacketConn{
+ stack: s,
+ ep: ep,
+ wq: wq,
+ }
+ c.deadlineTimer.init()
+ return c
+}
+
// DialUDP creates a new PacketConn.
//
// If laddr is nil, a local address is automatically chosen.
@@ -580,12 +591,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
}
}
- c := PacketConn{
- stack: s,
- ep: ep,
- wq: &wq,
- }
- c.deadlineTimer.init()
+ c := NewPacketConn(s, &wq, ep)
if raddr != nil {
if err := c.ep.Connect(*raddr); err != nil {
@@ -599,7 +605,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
}
}
- return &c, nil
+ return c, nil
}
func (c *PacketConn) newOpError(op string, err error) *net.OpError {
@@ -622,7 +628,7 @@ func (c *PacketConn) RemoteAddr() net.Addr {
if err != nil {
return nil
}
- return fullToTCPAddr(a)
+ return fullToUDPAddr(a)
}
// Read implements net.Conn.Read
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index d6c31bfa2..563bc78ea 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"prependable.go",
"view.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/buffer",
visibility = ["//visibility:public"],
)
@@ -17,5 +15,5 @@ go_test(
name = "buffer_test",
size = "small",
srcs = ["view_test.go"],
- embed = [":buffer"],
+ library = ":buffer",
)
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
index b6fa6fc37..ed434807f 100644
--- a/pkg/tcpip/checker/BUILD
+++ b/pkg/tcpip/checker/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,7 +6,6 @@ go_library(
name = "checker",
testonly = 1,
srcs = ["checker.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/checker",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 885d773b0..4d6ae0871 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -771,6 +771,56 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
}
}
+// NDPNSOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Neighbor Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNS message as far as the size is concerned.
+func NDPNSOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ it, err := ns.Options().Iter(true)
+ if err != nil {
+ t.Errorf("opts.Iter(true): %s", err)
+ return
+ }
+
+ i := 0
+ for {
+ opt, done, _ := it.Next()
+ if done {
+ break
+ }
+
+ if i >= len(opts) {
+ t.Errorf("got unexpected option: %s", opt)
+ continue
+ }
+
+ switch wantOpt := opts[i].(type) {
+ case header.NDPSourceLinkLayerAddressOption:
+ gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
+ t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
+ }
+ default:
+ panic("not implemented")
+ }
+
+ i++
+ }
+
+ if missing := opts[i:]; len(missing) > 0 {
+ t.Errorf("missing options: %s", missing)
+ }
+ }
+}
+
// NDPRS creates a checker that checks that the packet contains a valid NDP
// Router Solicitation message (as per the raw wire format).
func NDPRS() NetworkChecker {
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
index e648efa71..ff2719291 100644
--- a/pkg/tcpip/hash/jenkins/BUILD
+++ b/pkg/tcpip/hash/jenkins/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "jenkins",
srcs = ["jenkins.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins",
visibility = ["//visibility:public"],
)
@@ -16,5 +14,5 @@ go_test(
srcs = [
"jenkins_test.go",
],
- embed = [":jenkins"],
+ library = ":jenkins",
)
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index cd747d100..9da0d71f8 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -24,7 +23,6 @@ go_library(
"tcp.go",
"udp.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/header",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
@@ -59,7 +57,7 @@ go_test(
"eth_test.go",
"ndp_test.go",
],
- embed = [":header"],
+ library = ":header",
deps = [
"//pkg/tcpip",
"@com_github_google_go-cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 9749c7f4d..14a4b2b44 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -45,12 +45,139 @@ func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
return ChecksumCombine(uint16(v), uint16(v>>16)), odd
}
+func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
+ v := initial
+
+ if odd {
+ v += uint32(buf[0])
+ buf = buf[1:]
+ }
+
+ l := len(buf)
+ odd = l&1 != 0
+ if odd {
+ l--
+ v += uint32(buf[l]) << 8
+ }
+ for (l - 64) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[64:]
+ l = l - 64
+ }
+ if (l - 32) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[32:]
+ l = l - 32
+ }
+ if (l - 16) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[16:]
+ l = l - 16
+ }
+ if (l - 8) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ buf = buf[8:]
+ l = l - 8
+ }
+ if (l - 4) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ buf = buf[4:]
+ l = l - 4
+ }
+
+ // At this point since l was even before we started unrolling
+ // there can be only two bytes left to add.
+ if l != 0 {
+ v += (uint32(buf[0]) << 8) + uint32(buf[1])
+ }
+
+ return ChecksumCombine(uint16(v), uint16(v>>16)), odd
+}
+
+// ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in
+// the given byte array. This function uses a non-optimized implementation. Its
+// only retained for reference and to use as a benchmark/test. Most code should
+// use the header.Checksum function.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumOld(buf []byte, initial uint16) uint16 {
+ s, _ := calculateChecksum(buf, false, uint32(initial))
+ return s
+}
+
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
-// given byte array.
+// given byte array. This function uses an optimized unrolled version of the
+// checksum algorithm.
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
- s, _ := calculateChecksum(buf, false, uint32(initial))
+ s, _ := unrolledCalculateChecksum(buf, false, uint32(initial))
return s
}
@@ -86,7 +213,7 @@ func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, siz
}
v = v[:l]
- sum, odd = calculateChecksum(v, odd, uint32(sum))
+ sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum))
size -= len(v)
if size == 0 {
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index 86b466c1c..309403482 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -17,6 +17,8 @@
package header_test
import (
+ "fmt"
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -107,3 +109,63 @@ func TestChecksumVVWithOffset(t *testing.T) {
})
}
}
+
+func TestChecksum(t *testing.T) {
+ var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024}
+ type testCase struct {
+ buf []byte
+ initial uint16
+ csumOrig uint16
+ csumNew uint16
+ }
+ testCases := make([]testCase, 100000)
+ // Ensure same buffer generation for test consistency.
+ rnd := rand.New(rand.NewSource(42))
+ for i := range testCases {
+ testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)])
+ testCases[i].initial = uint16(rnd.Intn(65536))
+ rnd.Read(testCases[i].buf)
+ }
+
+ for i := range testCases {
+ testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial)
+ testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial)
+ if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want {
+ t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want)
+ }
+ }
+}
+
+func BenchmarkChecksum(b *testing.B) {
+ var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536}
+
+ checkSumImpls := []struct {
+ fn func([]byte, uint16) uint16
+ name string
+ }{
+ {header.ChecksumOld, fmt.Sprintf("checksum_old")},
+ {header.Checksum, fmt.Sprintf("checksum")},
+ }
+
+ for _, csumImpl := range checkSumImpls {
+ // Ensure same buffer generation for test consistency.
+ rnd := rand.New(rand.NewSource(42))
+ for _, bufSz := range bufSizes {
+ b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) {
+ tc := struct {
+ buf []byte
+ initial uint16
+ csum uint16
+ }{
+ buf: make([]byte, bufSz),
+ initial: uint16(rnd.Intn(65536)),
+ }
+ rnd.Read(tc.buf)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ tc.csum = csumImpl.fn(tc.buf, tc.initial)
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index b4037b6c8..c7ee2de57 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -52,7 +52,7 @@ const (
// ICMPv6NeighborAdvertSize is size of a neighbor advertisement
// including the NDP Target Link Layer option for an Ethernet
// address.
- ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + ndpTargetEthernetLinkLayerAddressSize
+ ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize
// ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
ICMPv6EchoMinimumSize = 8
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 06e0bace2..e6a6ad39b 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -17,6 +17,7 @@ package header
import (
"encoding/binary"
"errors"
+ "fmt"
"math"
"time"
@@ -24,13 +25,17 @@ import (
)
const (
- // NDPTargetLinkLayerAddressOptionType is the type of the Target
- // Link-Layer Address option, as per RFC 4861 section 4.6.1.
+ // NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer
+ // Address option, as per RFC 4861 section 4.6.1.
+ NDPSourceLinkLayerAddressOptionType = 1
+
+ // NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer
+ // Address option, as per RFC 4861 section 4.6.1.
NDPTargetLinkLayerAddressOptionType = 2
- // ndpTargetEthernetLinkLayerAddressSize is the size of a Target
- // Link Layer Option for an Ethernet address.
- ndpTargetEthernetLinkLayerAddressSize = 8
+ // NDPLinkLayerAddressSize is the size of a Source or Target Link Layer
+ // Address option for an Ethernet address.
+ NDPLinkLayerAddressSize = 8
// NDPPrefixInformationType is the type of the Prefix Information
// option, as per RFC 4861 section 4.6.2.
@@ -189,6 +194,9 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
i.opts = i.opts[numBytes:]
switch t {
+ case NDPSourceLinkLayerAddressOptionType:
+ return NDPSourceLinkLayerAddressOption(body), false, nil
+
case NDPTargetLinkLayerAddressOptionType:
return NDPTargetLinkLayerAddressOption(body), false, nil
@@ -293,6 +301,8 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
// NDPOption is the set of functions to be implemented by all NDP option types.
type NDPOption interface {
+ fmt.Stringer
+
// Type returns the type of the receiver.
Type() uint8
@@ -368,6 +378,46 @@ func (b NDPOptionsSerializer) Length() int {
return l
}
+// NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option
+// as defined by RFC 4861 section 4.6.1.
+//
+// It is the first X bytes following the NDP option's Type and Length field
+// where X is the value in Length multiplied by lengthByteUnits - 2 bytes.
+type NDPSourceLinkLayerAddressOption tcpip.LinkAddress
+
+// Type implements NDPOption.Type.
+func (o NDPSourceLinkLayerAddressOption) Type() uint8 {
+ return NDPSourceLinkLayerAddressOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPSourceLinkLayerAddressOption) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPSourceLinkLayerAddressOption) serializeInto(b []byte) int {
+ return copy(b, o)
+}
+
+// String implements fmt.Stringer.String.
+func (o NDPSourceLinkLayerAddressOption) String() string {
+ return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o))
+}
+
+// EthernetAddress will return an ethernet (MAC) address if the
+// NDPSourceLinkLayerAddressOption's body has at minimum EthernetAddressSize
+// bytes. If the body has more than EthernetAddressSize bytes, only the first
+// EthernetAddressSize bytes are returned as that is all that is needed for an
+// Ethernet address.
+func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
+ if len(o) >= EthernetAddressSize {
+ return tcpip.LinkAddress(o[:EthernetAddressSize])
+ }
+
+ return tcpip.LinkAddress([]byte(nil))
+}
+
// NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option
// as defined by RFC 4861 section 4.6.1.
//
@@ -390,6 +440,11 @@ func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int {
return copy(b, o)
}
+// String implements fmt.Stringer.String.
+func (o NDPTargetLinkLayerAddressOption) String() string {
+ return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o))
+}
+
// EthernetAddress will return an ethernet (MAC) address if the
// NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize
// bytes. If the body has more than EthernetAddressSize bytes, only the first
@@ -436,6 +491,17 @@ func (o NDPPrefixInformation) serializeInto(b []byte) int {
return used
}
+// String implements fmt.Stringer.String.
+func (o NDPPrefixInformation) String() string {
+ return fmt.Sprintf("%T(O=%t, A=%t, PL=%s, VL=%s, Prefix=%s)",
+ o,
+ o.OnLinkFlag(),
+ o.AutonomousAddressConfigurationFlag(),
+ o.PreferredLifetime(),
+ o.ValidLifetime(),
+ o.Subnet())
+}
+
// PrefixLength returns the value in the number of leading bits in the Prefix
// that are valid.
//
@@ -545,6 +611,11 @@ func (o NDPRecursiveDNSServer) serializeInto(b []byte) int {
return used
}
+// String implements fmt.Stringer.String.
+func (o NDPRecursiveDNSServer) String() string {
+ return fmt.Sprintf("%T(%s valid for %s)", o, o.Addresses(), o.Lifetime())
+}
+
// Lifetime returns the length of time that the DNS server addresses
// in this option may be used for name resolution.
//
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index 2c439d70c..1cb9f5dc8 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -153,6 +153,125 @@ func TestNDPRouterAdvert(t *testing.T) {
}
}
+// TestNDPSourceLinkLayerAddressOptionEthernetAddress tests getting the
+// Ethernet address from an NDPSourceLinkLayerAddressOption.
+func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expected tcpip.LinkAddress
+ }{
+ {
+ "ValidMAC",
+ []byte{1, 2, 3, 4, 5, 6},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ "SLLBodyTooShort",
+ []byte{1, 2, 3, 4, 5},
+ tcpip.LinkAddress([]byte(nil)),
+ },
+ {
+ "SLLBodyLargerThanNeeded",
+ []byte{1, 2, 3, 4, 5, 6, 7, 8},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ sll := NDPSourceLinkLayerAddressOption(test.buf)
+ if got := sll.EthernetAddress(); got != test.expected {
+ t.Errorf("got sll.EthernetAddress = %s, want = %s", got, test.expected)
+ }
+ })
+ }
+}
+
+// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a
+// NDPSourceLinkLayerAddressOption.
+func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expectedBuf []byte
+ addr tcpip.LinkAddress
+ }{
+ {
+ "Ethernet",
+ make([]byte, 8),
+ []byte{1, 1, 1, 2, 3, 4, 5, 6},
+ "\x01\x02\x03\x04\x05\x06",
+ },
+ {
+ "Padding",
+ []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
+ "\x01\x02\x03\x04\x05\x06\x07\x08",
+ },
+ {
+ "Empty",
+ nil,
+ nil,
+ "",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+ serializer := NDPOptionsSerializer{
+ NDPSourceLinkLayerAddressOption(test.addr),
+ }
+ if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
+ t.Fatalf("got Length = %d, want = %d", got, want)
+ }
+ opts.Serialize(serializer)
+ if !bytes.Equal(test.buf, test.expectedBuf) {
+ t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ if len(test.expectedBuf) > 0 {
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
+ t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
+ }
+ sll := next.(NDPSourceLinkLayerAddressOption)
+ if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
+ t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+
+ if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
+ t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want)
+ }
+ }
+
+ // Iterator should not return anything else.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+ })
+ }
+}
+
// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the
// Ethernet address from an NDPTargetLinkLayerAddressOption.
func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
@@ -186,7 +305,6 @@ func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
}
})
}
-
}
// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a
@@ -212,8 +330,8 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
},
{
"Empty",
- []byte{},
- []byte{},
+ nil,
+ nil,
"",
},
}
@@ -246,7 +364,7 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
t.Fatal("got Next = (_, true, _), want = (_, false, _)")
}
if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
- t.Fatalf("got Type %= %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
+ t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
}
tll := next.(NDPTargetLinkLayerAddressOption)
if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
@@ -254,7 +372,7 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
}
if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
- t.Errorf("got tll.MACAddress = %s, want = %s", got, want)
+ t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want)
}
}
@@ -510,7 +628,7 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) {
t.Fatal("got Next = (_, true, _), want = (_, false, _)")
}
if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
- t.Fatalf("got Type %= %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+ t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
}
opt, ok := next.(NDPRecursiveDNSServer)
@@ -553,6 +671,16 @@ func TestNDPOptionsIterCheck(t *testing.T) {
ErrNDPOptZeroLength,
},
{
+ "ValidSourceLinkLayerAddressOption",
+ []byte{1, 1, 1, 2, 3, 4, 5, 6},
+ nil,
+ },
+ {
+ "TooSmallSourceLinkLayerAddressOption",
+ []byte{1, 1, 1, 2, 3, 4, 5},
+ ErrNDPOptBufExhausted,
+ },
+ {
"ValidTargetLinkLayerAddressOption",
[]byte{2, 1, 1, 2, 3, 4, 5, 6},
nil,
@@ -603,10 +731,13 @@ func TestNDPOptionsIterCheck(t *testing.T) {
ErrNDPOptMalformedBody,
},
{
- "ValidTargetLinkLayerAddressWithPrefixInformation",
+ "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation",
[]byte{
+ // Source Link-Layer Address.
+ 1, 1, 1, 2, 3, 4, 5, 6,
+
// Target Link-Layer Address.
- 2, 1, 1, 2, 3, 4, 5, 6,
+ 2, 1, 7, 8, 9, 10, 11, 12,
// Prefix information.
3, 4, 43, 64,
@@ -621,10 +752,13 @@ func TestNDPOptionsIterCheck(t *testing.T) {
nil,
},
{
- "ValidTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
+ "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
[]byte{
+ // Source Link-Layer Address.
+ 1, 1, 1, 2, 3, 4, 5, 6,
+
// Target Link-Layer Address.
- 2, 1, 1, 2, 3, 4, 5, 6,
+ 2, 1, 7, 8, 9, 10, 11, 12,
// 255 is an unrecognized type. If 255 ends up
// being the type for some recognized type,
@@ -714,8 +848,11 @@ func TestNDPOptionsIterCheck(t *testing.T) {
// here.
func TestNDPOptionsIter(t *testing.T) {
buf := []byte{
+ // Source Link-Layer Address.
+ 1, 1, 1, 2, 3, 4, 5, 6,
+
// Target Link-Layer Address.
- 2, 1, 1, 2, 3, 4, 5, 6,
+ 2, 1, 7, 8, 9, 10, 11, 12,
// 255 is an unrecognized type. If 255 ends up being the type
// for some recognized type, update 255 to some other
@@ -740,7 +877,7 @@ func TestNDPOptionsIter(t *testing.T) {
t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
}
- // Test the first (Taret Link-Layer) option.
+ // Test the first (Source Link-Layer) option.
next, done, err := it.Next()
if err != nil {
t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
@@ -748,7 +885,22 @@ func TestNDPOptionsIter(t *testing.T) {
if done {
t.Fatal("got Next = (_, true, _), want = (_, false, _)")
}
- if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) {
+ if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) {
+ t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+ if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
+ }
+
+ // Test the next (Target Link-Layer) option.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) {
t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
}
if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
@@ -764,7 +916,7 @@ func TestNDPOptionsIter(t *testing.T) {
if done {
t.Fatal("got Next = (_, true, _), want = (_, false, _)")
}
- if got, want := next.(NDPPrefixInformation), buf[26:][:30]; !bytes.Equal(got, want) {
+ if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) {
t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
}
if got := next.Type(); got != NDPPrefixInformationType {
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
index 2893c80cd..d1b73cfdf 100644
--- a/pkg/tcpip/iptables/BUILD
+++ b/pkg/tcpip/iptables/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -9,10 +9,10 @@ go_library(
"targets.go",
"types.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
"//pkg/tcpip",
+ "//pkg/tcpip/header",
],
)
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
index 605a71679..4bfb3149e 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/iptables/iptables.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
// Table names.
@@ -184,8 +185,16 @@ func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename stri
panic(fmt.Sprintf("Traversed past the entire list of iptables rules in table %q.", tablename))
}
+// Precondition: pk.NetworkHeader is set.
func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) Verdict {
rule := table.Rules[ruleIdx]
+
+ // First check whether the packet matches the IP header filter.
+ // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+ if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() {
+ return Continue
+ }
+
// Go through each rule matcher. If they all match, run
// the rule target.
for _, matcher := range rule.Matchers {
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
index 9f6906100..50893cc55 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/iptables/types.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -14,7 +14,9 @@
package iptables
-import "gvisor.dev/gvisor/pkg/tcpip"
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
// A Hook specifies one of the hooks built into the network stack.
//
@@ -151,6 +153,9 @@ func (table *Table) SetMetadata(metadata interface{}) {
// packets this rule applies to. If there are no matchers in the rule, it
// applies to any packet.
type Rule struct {
+ // Filter holds basic IP filtering fields common to every rule.
+ Filter IPHeaderFilter
+
// Matchers is the list of matchers for this rule.
Matchers []Matcher
@@ -158,6 +163,12 @@ type Rule struct {
Target Target
}
+// IPHeaderFilter holds basic IP filtering data common to every rule.
+type IPHeaderFilter struct {
+ // Protocol matches the transport protocol.
+ Protocol tcpip.TransportProtocolNumber
+}
+
// A Matcher is the interface for matching packets.
type Matcher interface {
// Match returns whether the packet matches and whether the packet
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 7dbc05754..3974c464e 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "channel",
srcs = ["channel.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 70188551f..71b9da797 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -18,6 +18,8 @@
package channel
import (
+ "context"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -38,25 +40,52 @@ type Endpoint struct {
linkAddr tcpip.LinkAddress
GSO bool
- // C is where outbound packets are queued.
- C chan PacketInfo
+ // c is where outbound packets are queued.
+ c chan PacketInfo
}
// New creates a new channel endpoint.
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
return &Endpoint{
- C: make(chan PacketInfo, size),
+ c: make(chan PacketInfo, size),
mtu: mtu,
linkAddr: linkAddr,
}
}
+// Close closes e. Further packet injections will panic. Reads continue to
+// succeed until all packets are read.
+func (e *Endpoint) Close() {
+ close(e.c)
+}
+
+// Read does non-blocking read for one packet from the outbound packet queue.
+func (e *Endpoint) Read() (PacketInfo, bool) {
+ select {
+ case pkt := <-e.c:
+ return pkt, true
+ default:
+ return PacketInfo{}, false
+ }
+}
+
+// ReadContext does blocking read for one packet from the outbound packet queue.
+// It can be cancelled by ctx, and in this case, it returns false.
+func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ select {
+ case pkt := <-e.c:
+ return pkt, true
+ case <-ctx.Done():
+ return PacketInfo{}, false
+ }
+}
+
// Drain removes all outbound packets from the channel and counts them.
func (e *Endpoint) Drain() int {
c := 0
for {
select {
- case <-e.C:
+ case <-e.c:
c++
default:
return c
@@ -125,7 +154,7 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.Ne
}
select {
- case e.C <- p:
+ case e.c <- p:
default:
}
@@ -150,7 +179,7 @@ packetLoop:
}
select {
- case e.C <- p:
+ case e.c <- p:
n++
default:
break packetLoop
@@ -169,7 +198,7 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
}
select {
- case e.C <- p:
+ case e.c <- p:
default:
}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index 66cc53ed4..abe725548 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -13,7 +12,6 @@ go_library(
"mmap_unsafe.go",
"packet_dispatchers.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased",
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
@@ -30,7 +28,7 @@ go_test(
name = "fdbased_test",
size = "small",
srcs = ["endpoint_test.go"],
- embed = [":fdbased"],
+ library = ":fdbased",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD
index f35fcdff4..6bf3805b7 100644
--- a/pkg/tcpip/link/loopback/BUILD
+++ b/pkg/tcpip/link/loopback/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "loopback",
srcs = ["loopback.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index 1ac7948b6..82b441b79 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "muxed",
srcs = ["injectable.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
@@ -19,7 +17,7 @@ go_test(
name = "muxed_test",
size = "small",
srcs = ["injectable_test.go"],
- embed = [":muxed"],
+ library = ":muxed",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index d8211e93d..14b527bc2 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -12,7 +12,6 @@ go_library(
"errors.go",
"rawfile_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index 09165dd4c..13243ebbb 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -11,7 +10,6 @@ go_library(
"sharedmem_unsafe.go",
"tx.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
@@ -30,7 +28,7 @@ go_test(
srcs = [
"sharedmem_test.go",
],
- embed = [":sharedmem"],
+ library = ":sharedmem",
deps = [
"//pkg/sync",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
index a0d4ad0be..87020ec08 100644
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -11,7 +10,6 @@ go_library(
"rx.go",
"tx.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe",
visibility = ["//visibility:public"],
)
@@ -20,6 +18,6 @@ go_test(
srcs = [
"pipe_test.go",
],
- embed = [":pipe"],
+ library = ":pipe",
deps = ["//pkg/sync"],
)
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
index 8c9234d54..3ba06af73 100644
--- a/pkg/tcpip/link/sharedmem/queue/BUILD
+++ b/pkg/tcpip/link/sharedmem/queue/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"rx.go",
"tx.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
@@ -22,7 +20,7 @@ go_test(
srcs = [
"queue_test.go",
],
- embed = [":queue"],
+ library = ":queue",
deps = [
"//pkg/tcpip/link/sharedmem/pipe",
],
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
index d6ae0368a..230a8d53a 100644
--- a/pkg/tcpip/link/sniffer/BUILD
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,7 +8,6 @@ go_library(
"pcap.go",
"sniffer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index a71a493fc..e5096ea38 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -1,10 +1,9 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "tun",
srcs = ["tun_unsafe.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun",
visibility = ["//visibility:public"],
)
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 134837943..0956d2c65 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -8,7 +7,6 @@ go_library(
srcs = [
"waitable.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable",
visibility = ["//visibility:public"],
deps = [
"//pkg/gate",
@@ -23,7 +21,7 @@ go_test(
srcs = [
"waitable_test.go",
],
- embed = [":waitable"],
+ library = ":waitable",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 9d16ff8c9..6a4839fb8 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index e7617229b..eddf7b725 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "arp",
srcs = ["arp.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 8e6048a21..03cf03b6d 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -15,6 +15,7 @@
package arp_test
import (
+ "context"
"strconv"
"testing"
"time"
@@ -83,7 +84,7 @@ func newTestContext(t *testing.T) *testContext {
}
func (c *testContext) cleanup() {
- close(c.linkEP.C)
+ c.linkEP.Close()
}
func TestDirectRequest(t *testing.T) {
@@ -110,7 +111,7 @@ func TestDirectRequest(t *testing.T) {
for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
inject(address)
- pi := <-c.linkEP.C
+ pi, _ := c.linkEP.ReadContext(context.Background())
if pi.Proto != arp.ProtocolNumber {
t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
}
@@ -134,12 +135,11 @@ func TestDirectRequest(t *testing.T) {
}
inject(stackAddrBad)
- select {
- case pkt := <-c.linkEP.C:
+ // Sleep tests are gross, but this will only potentially flake
+ // if there's a bug. If there is no bug this will reliably
+ // succeed.
+ ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ if pkt, ok := c.linkEP.ReadContext(ctx); ok {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
- case <-time.After(100 * time.Millisecond):
- // Sleep tests are gross, but this will only potentially flake
- // if there's a bug. If there is no bug this will reliably
- // succeed.
}
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index ed16076fd..d1c728ccf 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -24,7 +23,6 @@ go_library(
"reassembler.go",
"reassembler_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
@@ -42,6 +40,6 @@ go_test(
"fragmentation_test.go",
"reassembler_test.go",
],
- embed = [":fragmentation"],
+ library = ":fragmentation",
deps = ["//pkg/tcpip/buffer"],
)
diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD
index e6db5c0b0..872165866 100644
--- a/pkg/tcpip/network/hash/BUILD
+++ b/pkg/tcpip/network/hash/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "hash",
srcs = ["hash.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/hash",
visibility = ["//visibility:public"],
deps = [
"//pkg/rand",
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 4e2aae9a3..0fef2b1f1 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"icmp.go",
"ipv4.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 0a1453b31..85512f9b2 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -353,7 +353,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
}
pkt.NetworkHeader = headerView[:h.HeaderLength()]
- // iptables filtering.
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and will not be forwarded.
ipt := e.stack.IPTables()
if ok := ipt.Check(iptables.Input, pkt); !ok {
// iptables is telling us to drop the packet.
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index e4e273460..fb11874c6 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"icmp.go",
"ipv6.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
@@ -27,7 +25,7 @@ go_test(
"ipv6_test.go",
"ndp_test.go",
],
- embed = [":ipv6"],
+ library = ":ipv6",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 1c3410618..dc20c0fd7 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -137,21 +137,24 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
}
ns := header.NDPNeighborSolicit(h.NDPPayload())
+ it, err := ns.Options().Iter(true)
+ if err != nil {
+ // If we have a malformed NDP NS option, drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
targetAddr := ns.TargetAddress()
s := r.Stack()
rxNICID := r.NICID()
-
- isTentative, err := s.IsAddrTentative(rxNICID, targetAddr)
- if err != nil {
+ if isTentative, err := s.IsAddrTentative(rxNICID, targetAddr); err != nil {
// We will only get an error if rxNICID is unrecognized,
// which should not happen. For now short-circuit this
// packet.
//
// TODO(b/141002840): Handle this better?
return
- }
-
- if isTentative {
+ } else if isTentative {
// If the target address is tentative and the source
// of the packet is a unicast (specified) address, then
// the source of the packet is attempting to perform
@@ -185,6 +188,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
return
}
+ // If the NS message has the source link layer option, update the link
+ // address cache with the link address for the sender of the message.
+ //
+ // TODO(b/148429853): Properly process the NS message and do Neighbor
+ // Unreachability Detection.
+ for {
+ opt, done, _ := it.Next()
+ if done {
+ break
+ }
+
+ switch opt := opt.(type) {
+ case header.NDPSourceLinkLayerAddressOption:
+ e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, opt.EthernetAddress())
+ }
+ }
+
optsSerializer := header.NDPOptionsSerializer{
header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]),
}
@@ -211,15 +231,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
r.LocalAddress = targetAddr
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- // TODO(tamird/ghanan): there exists an explicit NDP option that is
- // used to update the neighbor table with link addresses for a
- // neighbor from an NS (see the Source Link Layer option RFC
- // 4861 section 4.6.1 and section 7.2.3).
- //
- // Furthermore, the entirety of NDP handling here seems to be
- // contradicted by RFC 4861.
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress)
-
// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
//
// 7.1.2. Validation of Neighbor Advertisements
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index a2fdc5dcd..7a6820643 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "context"
"reflect"
"strings"
"testing"
@@ -264,8 +265,8 @@ func newTestContext(t *testing.T) *testContext {
}
func (c *testContext) cleanup() {
- close(c.linkEP0.C)
- close(c.linkEP1.C)
+ c.linkEP0.Close()
+ c.linkEP1.Close()
}
type routeArgs struct {
@@ -276,7 +277,7 @@ type routeArgs struct {
func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
t.Helper()
- pi := <-args.src.C
+ pi, _ := args.src.ReadContext(context.Background())
{
views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index fe895b376..bd732f93f 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -70,6 +70,141 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
return s, ep
}
+// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving an
+// NDP NS message with the Source Link Layer Address option results in a
+// new entry in the link address cache for the sender of the message.
+func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ ns.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+ if linkAddr != linkAddr1 {
+ t.Errorf("got link address = %s, want = %s", linkAddr, linkAddr1)
+ }
+}
+
+// TestNeighorSolicitationWithInvalidSourceLinkLayerOption tests that receiving
+// an NDP NS message with an invalid Source Link Layer Address option does not
+// result in a new entry in the link address cache for the sender of the
+// message.
+func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ }{
+ {
+ name: "Too Small",
+ optsBuf: []byte{1, 1, 1, 2, 3, 4, 5},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+ if linkAddr != "" {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (%s, _, ), want = ('', _, _)", nicID, lladdr1, lladdr0, ProtocolNumber, linkAddr)
+ }
+ })
+ }
+}
+
// TestHopLimitValidation is a test that makes sure that NDP packets are only
// received if their IP header's hop limit is set to 255.
func TestHopLimitValidation(t *testing.T) {
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index a6ef3bdcc..2bad05a2e 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -1,12 +1,10 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "ports",
srcs = ["ports.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/ports",
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
@@ -17,7 +15,7 @@ go_library(
go_test(
name = "ports_test",
srcs = ["ports_test.go"],
- embed = [":ports"],
+ library = ":ports",
deps = [
"//pkg/tcpip",
],
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
index d7496fde6..cf0a5fefe 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 2239c1e66..0ab089208 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -164,7 +164,7 @@ func main() {
// Create TCP endpoint.
var wq waiter.Queue
ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
+ if e != nil {
log.Fatal(e)
}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD
index 875561566..43264b76d 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index bca73cbb1..9e37cab18 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -168,7 +168,7 @@ func main() {
// Create TCP endpoint, bind it, then start listening.
var wq waiter.Queue
ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
- if err != nil {
+ if e != nil {
log.Fatal(e)
}
diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD
index b31ddba2f..45f503845 100644
--- a/pkg/tcpip/seqnum/BUILD
+++ b/pkg/tcpip/seqnum/BUILD
@@ -1,10 +1,9 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "seqnum",
srcs = ["seqnum.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum",
visibility = ["//visibility:public"],
)
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 783351a69..f5b750046 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -30,7 +29,6 @@ go_library(
"stack_global_state.go",
"transport_demuxer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/stack",
visibility = ["//visibility:public"],
deps = [
"//pkg/ilist",
@@ -81,7 +79,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = ["linkaddrcache_test.go"],
- embed = [":stack"],
+ library = ":stack",
deps = [
"//pkg/sleep",
"//pkg/sync",
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 7d4b41dfa..31294345d 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -15,7 +15,6 @@
package stack
import (
- "fmt"
"log"
"math/rand"
"time"
@@ -168,8 +167,8 @@ type NDPDispatcher interface {
// reason, such as the address being removed). If an error occured
// during DAD, err will be set and resolved must be ignored.
//
- // This function is permitted to block indefinitely without interfering
- // with the stack's operation.
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
// OnDefaultRouterDiscovered will be called when a new default router is
@@ -429,8 +428,13 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
return tcpip.ErrAddressFamilyNotSupported
}
- // Should not attempt to perform DAD on an address that is currently in
- // the DAD process.
+ if ref.getKind() != permanentTentative {
+ // The endpoint should be marked as tentative since we are starting DAD.
+ log.Fatalf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())
+ }
+
+ // Should not attempt to perform DAD on an address that is currently in the
+ // DAD process.
if _, ok := ndp.dad[addr]; ok {
// Should never happen because we should only ever call this function for
// newly created addresses. If we attemped to "add" an address that already
@@ -438,77 +442,79 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
// address, or its reference count would have been increased without doing
// the work that would have been done for an address that was brand new.
// See NIC.addAddressLocked.
- panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
+ log.Fatalf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())
}
remaining := ndp.configs.DupAddrDetectTransmits
-
- {
- done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref)
- if err != nil {
- return err
- }
- if done {
- return nil
- }
+ if remaining == 0 {
+ ref.setKind(permanent)
+ return nil
}
- remaining--
-
var done bool
var timer *time.Timer
- timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() {
- var d bool
- var err *tcpip.Error
-
- // doDadIteration does a single iteration of the DAD loop.
- //
- // Returns true if the integrator needs to be informed of DAD
- // completing.
- doDadIteration := func() bool {
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
-
- if done {
- // If we reach this point, it means that the DAD
- // timer fired after another goroutine already
- // obtained the NIC lock and stopped DAD before
- // this function obtained the NIC lock. Simply
- // return here and do nothing further.
- return false
- }
+ // We initially start a timer to fire immediately because some of the DAD work
+ // cannot be done while holding the NIC's lock. This is effectively the same
+ // as starting a goroutine but we use a timer that fires immediately so we can
+ // reset it for the next DAD iteration.
+ timer = time.AfterFunc(0, func() {
+ ndp.nic.mu.RLock()
+ if done {
+ // If we reach this point, it means that the DAD timer fired after
+ // another goroutine already obtained the NIC lock and stopped DAD
+ // before this function obtained the NIC lock. Simply return here and do
+ // nothing further.
+ ndp.nic.mu.RUnlock()
+ return
+ }
- ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}]
- if !ok {
- // This should never happen.
- // We should have an endpoint for addr since we
- // are still performing DAD on it. If the
- // endpoint does not exist, but we are doing DAD
- // on it, then we started DAD at some point, but
- // forgot to stop it when the endpoint was
- // deleted.
- panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID()))
- }
+ if ref.getKind() != permanentTentative {
+ // The endpoint should still be marked as tentative since we are still
+ // performing DAD on it.
+ log.Fatalf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID())
+ }
- d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref)
- if err != nil || d {
- delete(ndp.dad, addr)
+ dadDone := remaining == 0
+ ndp.nic.mu.RUnlock()
- if err != nil {
- log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err)
- }
+ var err *tcpip.Error
+ if !dadDone {
+ err = ndp.sendDADPacket(addr)
+ }
- // Let the integrator know DAD has completed.
- return true
- }
+ ndp.nic.mu.Lock()
+ if done {
+ // If we reach this point, it means that DAD was stopped after we released
+ // the NIC's read lock and before we obtained the write lock.
+ ndp.nic.mu.Unlock()
+ return
+ }
+ if dadDone {
+ // DAD has resolved.
+ ref.setKind(permanent)
+ } else if err == nil {
+ // DAD is not done and we had no errors when sending the last NDP NS,
+ // schedule the next DAD timer.
remaining--
timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
- return false
+
+ ndp.nic.mu.Unlock()
+ return
}
- if doDadIteration() && ndp.nic.stack.ndpDisp != nil {
- ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err)
+ // At this point we know that either DAD is done or we hit an error sending
+ // the last NDP NS. Either way, clean up addr's DAD state and let the
+ // integrator know DAD has completed.
+ delete(ndp.dad, addr)
+ ndp.nic.mu.Unlock()
+
+ if err != nil {
+ log.Printf("ndpdad: error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err)
+ }
+
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err)
}
})
@@ -520,45 +526,16 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
return nil
}
-// doDuplicateAddressDetection is called on every iteration of the timer, and
-// when DAD starts.
-//
-// It handles resolving the address (if there are no more NS to send), or
-// sending the next NS if there are more NS to send.
-//
-// This function must only be called by IPv6 addresses that are currently
-// tentative.
-//
-// The NIC that ndp belongs to (n) MUST be locked.
+// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns
+// addr.
//
-// Returns true if DAD has resolved; false if DAD is still ongoing.
-func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) {
- if ref.getKind() != permanentTentative {
- // The endpoint should still be marked as tentative
- // since we are still performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
- }
-
- if remaining == 0 {
- // DAD has resolved.
- ref.setKind(permanent)
- return true, nil
- }
-
- // Send a new NS.
+// addr must be a tentative IPv6 address on ndp's NIC.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}]
- if !ok {
- // This should never happen as if we have the
- // address, we should have the solicited-node
- // address.
- panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr))
- }
- snmcRef.incRef()
- // Use the unspecified address as the source address when performing
- // DAD.
- r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false)
+ // Use the unspecified address as the source address when performing DAD.
+ ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
+ r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
@@ -569,15 +546,19 @@ func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining u
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}, tcpip.PacketBuffer{
- Header: hdr,
- }); err != nil {
+ if err := r.WritePacket(nil,
+ NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ TOS: DefaultTOS,
+ }, tcpip.PacketBuffer{Header: hdr},
+ ); err != nil {
sent.Dropped.Increment()
- return false, err
+ return err
}
sent.NeighborSolicit.Increment()
- return false, nil
+ return nil
}
// stopDuplicateAddressDetection ends a running Duplicate Address Detection
@@ -608,8 +589,8 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
delete(ndp.dad, addr)
// Let the integrator know DAD did not resolve.
- if ndp.nic.stack.ndpDisp != nil {
- go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
}
}
@@ -1212,7 +1193,7 @@ func (ndp *ndpState) startSolicitingRouters() {
ndp.rtrSolicitTimer = time.AfterFunc(delay, func() {
// Send an RS message with the unspecified source address.
- ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, true)
+ ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 1a52e0e68..bc7cfbcb4 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "context"
"encoding/binary"
"fmt"
"testing"
@@ -35,13 +36,14 @@ import (
)
const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
- linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
- defaultTimeout = 100 * time.Millisecond
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
+ defaultTimeout = 100 * time.Millisecond
+ defaultAsyncEventTimeout = time.Second
)
var (
@@ -301,6 +303,8 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s
// Included in the subtests is a test to make sure that an invalid
// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
func TestDADResolve(t *testing.T) {
+ const nicID = 1
+
tests := []struct {
name string
dupAddrDetectTransmits uint8
@@ -331,44 +335,36 @@ func TestDADResolve(t *testing.T) {
opts.NDPConfigs.RetransmitTimer = test.retransTimer
opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1)
s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
- }
-
- stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit
-
- // Should have sent an NDP NS immediately.
- if got := stat.Value(); got != 1 {
- t.Fatalf("got NeighborSolicit = %d, want = 1", got)
-
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
}
// Address should not be considered bound to the NIC yet
// (DAD ongoing).
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
// Wait for the remaining time - some delta (500ms), to
// make sure the address is still not resolved.
const delta = 500 * time.Millisecond
time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
// Wait for DAD to resolve.
@@ -385,8 +381,8 @@ func TestDADResolve(t *testing.T) {
if e.err != nil {
t.Fatal("got DAD error: ", e.err)
}
- if e.nicID != 1 {
- t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ if e.nicID != nicID {
+ t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID)
}
if e.addr != addr1 {
t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
@@ -395,37 +391,44 @@ func TestDADResolve(t *testing.T) {
t.Fatal("got DAD event w/ resolved = false, want = true")
}
}
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
}
// Should not have sent any more NS messages.
- if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
}
// Validate the sent Neighbor Solicitation messages.
for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
- p := <-e.C
+ p, _ := e.ReadContext(context.Background())
// Make sure its an IPv6 packet.
if p.Proto != header.IPv6ProtocolNumber {
t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
}
- // Check NDP packet.
+ // Check NDP NS packet.
+ //
+ // As per RFC 4861 section 4.3, a possible option is the Source Link
+ // Layer option, but this option MUST NOT be included when the source
+ // address of the packet is the unspecified address.
checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.SolicitedNodeAddr(addr1)),
checker.TTL(header.NDPHopLimit),
checker.NDPNS(
- checker.NDPNSTargetAddress(addr1)))
+ checker.NDPNSTargetAddress(addr1),
+ checker.NDPNSOptions(nil),
+ ))
}
})
}
-
}
// TestDADFail tests to make sure that the DAD process fails if another node is
@@ -498,7 +501,7 @@ func TestDADFail(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
ndpConfigs := stack.DefaultNDPConfigurations()
opts := stack.Options{
@@ -577,7 +580,7 @@ func TestDADFail(t *testing.T) {
// removed.
func TestDADStop(t *testing.T) {
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
ndpConfigs := stack.NDPConfigurations{
RetransmitTimer: time.Second,
@@ -1093,7 +1096,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
// Rx an RA from lladdr2 with huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
@@ -1110,7 +1113,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1349,7 +1352,7 @@ func TestPrefixDiscovery(t *testing.T) {
if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout):
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for prefix discovery event")
}
@@ -1688,7 +1691,7 @@ func TestAutoGenAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
@@ -1994,7 +1997,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -2034,7 +2037,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -2048,7 +2051,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultTimeout)
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout)
if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -2080,7 +2083,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
@@ -2095,7 +2098,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(newMinVLDuration + defaultTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -2220,7 +2223,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(minVLSeconds*time.Second + defaultTimeout):
+ case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -2708,7 +2711,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultTimeout):
+ case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3291,29 +3294,29 @@ func TestRouterSolicitation(t *testing.T) {
e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
waitForPkt := func(timeout time.Duration) {
t.Helper()
- select {
- case p := <-e.C:
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t,
- p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(),
- )
-
- case <-time.After(timeout):
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ p, ok := e.ReadContext(ctx)
+ if !ok {
t.Fatal("timed out waiting for packet")
+ return
}
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t,
+ p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(),
+ )
}
waitForNothing := func(timeout time.Duration) {
t.Helper()
- select {
- case <-e.C:
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet")
- case <-time.After(timeout):
}
}
s := stack.New(stack.Options{
@@ -3332,12 +3335,12 @@ func TestRouterSolicitation(t *testing.T) {
// times.
remaining := test.maxRtrSolicit
if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultTimeout)
+ waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
remaining--
}
for ; remaining > 0; remaining-- {
waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout)
- waitForPkt(2 * defaultTimeout)
+ waitForPkt(defaultAsyncEventTimeout)
}
// Make sure no more RS.
@@ -3368,20 +3371,21 @@ func TestStopStartSolicitingRouters(t *testing.T) {
e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
waitForPkt := func(timeout time.Duration) {
t.Helper()
- select {
- case p := <-e.C:
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t, p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS())
-
- case <-time.After(timeout):
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ p, ok := e.ReadContext(ctx)
+ if !ok {
t.Fatal("timed out waiting for packet")
+ return
}
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
@@ -3397,41 +3401,36 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Enable forwarding which should stop router solicitations.
s.SetForwarding(true)
- select {
- case <-e.C:
+ ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
// A single RS may have been sent before forwarding was enabled.
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ if _, ok = e.ReadContext(ctx); ok {
t.Fatal("Should not have sent more than one RS message")
- case <-time.After(interval + defaultTimeout):
}
- case <-time.After(delay + defaultTimeout):
}
// Enabling forwarding again should do nothing.
s.SetForwarding(true)
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after becoming a router")
- case <-time.After(delay + defaultTimeout):
}
// Disable forwarding which should start router solicitations.
s.SetForwarding(false)
- waitForPkt(delay + defaultTimeout)
- waitForPkt(interval + defaultTimeout)
- waitForPkt(interval + defaultTimeout)
- select {
- case <-e.C:
+ waitForPkt(delay + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
- case <-time.After(interval + defaultTimeout):
}
// Disabling forwarding again should do nothing.
s.SetForwarding(false)
- select {
- case <-e.C:
+ ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after becoming a router")
- case <-time.After(delay + defaultTimeout):
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index de88c0bfa..7dad9a8cb 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -35,24 +35,21 @@ type NIC struct {
linkEP LinkEndpoint
context NICContext
- mu sync.RWMutex
- spoofing bool
- promiscuous bool
- primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
- endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
- addressRanges []tcpip.Subnet
- mcastJoins map[NetworkEndpointID]int32
- // packetEPs is protected by mu, but the contained PacketEndpoint
- // values are not.
- packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
-
stats NICStats
- // ndp is the NDP related state for NIC.
- //
- // Note, read and write operations on ndp require that the NIC is
- // appropriately locked.
- ndp ndpState
+ mu struct {
+ sync.RWMutex
+ spoofing bool
+ promiscuous bool
+ primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
+ endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
+ addressRanges []tcpip.Subnet
+ mcastJoins map[NetworkEndpointID]int32
+ // packetEPs is protected by mu, but the contained PacketEndpoint
+ // values are not.
+ packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
+ ndp ndpState
+ }
}
// NICStats includes transmitted and received stats.
@@ -97,15 +94,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// of IPv6 is supported on this endpoint's LinkEndpoint.
nic := &NIC{
- stack: stack,
- id: id,
- name: name,
- linkEP: ep,
- context: ctx,
- primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint),
- endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
- mcastJoins: make(map[NetworkEndpointID]int32),
- packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint),
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ context: ctx,
stats: NICStats{
Tx: DirectionStats{
Packets: &tcpip.StatCounter{},
@@ -116,22 +109,26 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
Bytes: &tcpip.StatCounter{},
},
},
- ndp: ndpState{
- configs: stack.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- autoGenAddresses: make(map[tcpip.Address]autoGenAddressState),
- },
}
- nic.ndp.nic = nic
+ nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
+ nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
+ nic.mu.mcastJoins = make(map[NetworkEndpointID]int32)
+ nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
+ nic.mu.ndp = ndpState{
+ nic: nic,
+ configs: stack.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ autoGenAddresses: make(map[tcpip.Address]autoGenAddressState),
+ }
// Register supported packet endpoint protocols.
for _, netProto := range header.Ethertypes {
- nic.packetEPs[netProto] = []PacketEndpoint{}
+ nic.mu.packetEPs[netProto] = []PacketEndpoint{}
}
for _, netProto := range stack.networkProtocols {
- nic.packetEPs[netProto.Number()] = []PacketEndpoint{}
+ nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{}
}
return nic
@@ -215,7 +212,7 @@ func (n *NIC) enable() *tcpip.Error {
// and default routers). Therefore, soliciting RAs from other routers on
// a link is unnecessary for routers.
if !n.stack.forwarding {
- n.ndp.startSolicitingRouters()
+ n.mu.ndp.startSolicitingRouters()
}
return nil
@@ -230,8 +227,8 @@ func (n *NIC) becomeIPv6Router() {
n.mu.Lock()
defer n.mu.Unlock()
- n.ndp.cleanupHostOnlyState()
- n.ndp.stopSolicitingRouters()
+ n.mu.ndp.cleanupHostOnlyState()
+ n.mu.ndp.stopSolicitingRouters()
}
// becomeIPv6Host transitions n into an IPv6 host.
@@ -242,7 +239,7 @@ func (n *NIC) becomeIPv6Host() {
n.mu.Lock()
defer n.mu.Unlock()
- n.ndp.startSolicitingRouters()
+ n.mu.ndp.startSolicitingRouters()
}
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
@@ -254,13 +251,13 @@ func (n *NIC) attachLinkEndpoint() {
// setPromiscuousMode enables or disables promiscuous mode.
func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Lock()
- n.promiscuous = enable
+ n.mu.promiscuous = enable
n.mu.Unlock()
}
func (n *NIC) isPromiscuousMode() bool {
n.mu.RLock()
- rv := n.promiscuous
+ rv := n.mu.promiscuous
n.mu.RUnlock()
return rv
}
@@ -272,7 +269,7 @@ func (n *NIC) isLoopback() bool {
// setSpoofing enables or disables address spoofing.
func (n *NIC) setSpoofing(enable bool) {
n.mu.Lock()
- n.spoofing = enable
+ n.mu.spoofing = enable
n.mu.Unlock()
}
@@ -291,8 +288,8 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr t
defer n.mu.RUnlock()
var deprecatedEndpoint *referencedNetworkEndpoint
- for _, r := range n.primary[protocol] {
- if !r.isValidForOutgoing() {
+ for _, r := range n.mu.primary[protocol] {
+ if !r.isValidForOutgoingRLocked() {
continue
}
@@ -342,7 +339,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
n.mu.RLock()
defer n.mu.RUnlock()
- primaryAddrs := n.primary[header.IPv6ProtocolNumber]
+ primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber]
if len(primaryAddrs) == 0 {
return nil
@@ -425,7 +422,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
// hasPermanentAddrLocked returns true if n has a permanent (including currently
// tentative) address, addr.
func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool {
- ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
if !ok {
return false
@@ -436,24 +433,54 @@ func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool {
return kind == permanent || kind == permanentTentative
}
+type getRefBehaviour int
+
+const (
+ // spoofing indicates that the NIC's spoofing flag should be observed when
+ // getting a NIC's referenced network endpoint.
+ spoofing getRefBehaviour = iota
+
+ // promiscuous indicates that the NIC's promiscuous flag should be observed
+ // when getting a NIC's referenced network endpoint.
+ promiscuous
+
+ // forceSpoofing indicates that the NIC should be assumed to be spoofing,
+ // regardless of what the NIC's spoofing flag is when getting a NIC's
+ // referenced network endpoint.
+ forceSpoofing
+)
+
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
- return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous)
+ return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
}
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
- return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing)
+ return n.getRefOrCreateTemp(protocol, address, peb, spoofing)
}
// getRefEpOrCreateTemp returns the referenced network endpoint for the given
-// protocol and address. If none exists a temporary one may be created if
-// we are in promiscuous mode or spoofing.
-func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint {
+// protocol and address.
+//
+// If none exists a temporary one may be created if we are in promiscuous mode
+// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
+// Similarly, spoofing will only be checked if spoofing is true.
+func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
- if ref, ok := n.endpoints[id]; ok {
+ var spoofingOrPromiscuous bool
+ switch tempRef {
+ case spoofing:
+ spoofingOrPromiscuous = n.mu.spoofing
+ case promiscuous:
+ spoofingOrPromiscuous = n.mu.promiscuous
+ case forceSpoofing:
+ spoofingOrPromiscuous = true
+ }
+
+ if ref, ok := n.mu.endpoints[id]; ok {
// An endpoint with this id exists, check if it can be used and return it.
switch ref.getKind() {
case permanentExpired:
@@ -474,7 +501,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// the caller or if the address is found in the NIC's subnets.
createTempEP := spoofingOrPromiscuous
if !createTempEP {
- for _, sn := range n.addressRanges {
+ for _, sn := range n.mu.addressRanges {
// Skip the subnet address.
if address == sn.ID() {
continue
@@ -502,7 +529,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- if ref, ok := n.endpoints[id]; ok {
+ if ref, ok := n.mu.endpoints[id]; ok {
// No need to check the type as we are ok with expired endpoints at this
// point.
if ref.tryIncRef() {
@@ -543,7 +570,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
// Sanity check.
id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address}
- if ref, ok := n.endpoints[id]; ok {
+ if ref, ok := n.mu.endpoints[id]; ok {
// Endpoint already exists.
if kind != permanent {
return nil, tcpip.ErrDuplicateAddress
@@ -562,7 +589,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
ref.deprecated = deprecated
ref.configType = configType
- refs := n.primary[ref.protocol]
+ refs := n.mu.primary[ref.protocol]
for i, r := range refs {
if r == ref {
switch peb {
@@ -572,9 +599,9 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
if i == 0 {
return ref, nil
}
- n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
case NeverPrimaryEndpoint:
- n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
return ref, nil
}
}
@@ -637,13 +664,13 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
}
}
- n.endpoints[id] = ref
+ n.mu.endpoints[id] = ref
n.insertPrimaryEndpointLocked(ref, peb)
// If we are adding a tentative IPv6 address, start DAD.
if isIPv6Unicast && kind == permanentTentative {
- if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
+ if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
return nil, err
}
}
@@ -668,8 +695,8 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
n.mu.RLock()
defer n.mu.RUnlock()
- addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
- for nid, ref := range n.endpoints {
+ addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints))
+ for nid, ref := range n.mu.endpoints {
// Don't include tentative, expired or temporary endpoints to
// avoid confusion and prevent the caller from using those.
switch ref.getKind() {
@@ -695,7 +722,7 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
defer n.mu.RUnlock()
var addrs []tcpip.ProtocolAddress
- for proto, list := range n.primary {
+ for proto, list := range n.mu.primary {
for _, ref := range list {
// Don't include tentative, expired or tempory endpoints
// to avoid confusion and prevent the caller from using
@@ -726,7 +753,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
n.mu.RLock()
defer n.mu.RUnlock()
- list, ok := n.primary[proto]
+ list, ok := n.mu.primary[proto]
if !ok {
return tcpip.AddressWithPrefix{}
}
@@ -769,7 +796,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
// address.
func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
n.mu.Lock()
- n.addressRanges = append(n.addressRanges, subnet)
+ n.mu.addressRanges = append(n.mu.addressRanges, subnet)
n.mu.Unlock()
}
@@ -778,13 +805,13 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) {
n.mu.Lock()
// Use the same underlying array.
- tmp := n.addressRanges[:0]
- for _, sub := range n.addressRanges {
+ tmp := n.mu.addressRanges[:0]
+ for _, sub := range n.mu.addressRanges {
if sub != subnet {
tmp = append(tmp, sub)
}
}
- n.addressRanges = tmp
+ n.mu.addressRanges = tmp
n.mu.Unlock()
}
@@ -793,8 +820,8 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) {
func (n *NIC) AddressRanges() []tcpip.Subnet {
n.mu.RLock()
defer n.mu.RUnlock()
- sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints))
- for nid := range n.endpoints {
+ sns := make([]tcpip.Subnet, 0, len(n.mu.addressRanges)+len(n.mu.endpoints))
+ for nid := range n.mu.endpoints {
sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
if err != nil {
// This should never happen as the mask has been carefully crafted to
@@ -803,7 +830,7 @@ func (n *NIC) AddressRanges() []tcpip.Subnet {
}
sns = append(sns, sn)
}
- return append(sns, n.addressRanges...)
+ return append(sns, n.mu.addressRanges...)
}
// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required
@@ -813,9 +840,9 @@ func (n *NIC) AddressRanges() []tcpip.Subnet {
func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) {
switch peb {
case CanBePrimaryEndpoint:
- n.primary[r.protocol] = append(n.primary[r.protocol], r)
+ n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r)
case FirstPrimaryEndpoint:
- n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...)
+ n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...)
}
}
@@ -827,7 +854,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
// and was waiting (on the lock) to be removed and 2) the same address was
// re-added in the meantime by removing this endpoint from the list and
// adding a new one.
- if n.endpoints[id] != r {
+ if n.mu.endpoints[id] != r {
return
}
@@ -835,11 +862,11 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
panic("Reference count dropped to zero before being removed")
}
- delete(n.endpoints, id)
- refs := n.primary[r.protocol]
+ delete(n.mu.endpoints, id)
+ refs := n.mu.primary[r.protocol]
for i, ref := range refs {
if ref == r {
- n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
break
}
}
@@ -854,7 +881,7 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
}
func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
- r, ok := n.endpoints[NetworkEndpointID{addr}]
+ r, ok := n.mu.endpoints[NetworkEndpointID{addr}]
if !ok {
return tcpip.ErrBadLocalAddress
}
@@ -870,13 +897,13 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
// If we are removing a tentative IPv6 unicast address, stop
// DAD.
if kind == permanentTentative {
- n.ndp.stopDuplicateAddressDetection(addr)
+ n.mu.ndp.stopDuplicateAddressDetection(addr)
}
// If we are removing an address generated via SLAAC, cleanup
// its SLAAC resources and notify the integrator.
if r.configType == slaac {
- n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr)
+ n.mu.ndp.cleanupAutoGenAddrResourcesAndNotify(addr)
}
}
@@ -926,7 +953,7 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A
// outlined in RFC 3810 section 5.
id := NetworkEndpointID{addr}
- joins := n.mcastJoins[id]
+ joins := n.mu.mcastJoins[id]
if joins == 0 {
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
@@ -942,7 +969,7 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A
return err
}
}
- n.mcastJoins[id] = joins + 1
+ n.mu.mcastJoins[id] = joins + 1
return nil
}
@@ -960,7 +987,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
// before leaveGroupLocked is called.
func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
id := NetworkEndpointID{addr}
- joins := n.mcastJoins[id]
+ joins := n.mu.mcastJoins[id]
switch joins {
case 0:
// There are no joins with this address on this NIC.
@@ -971,7 +998,7 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
return err
}
}
- n.mcastJoins[id] = joins - 1
+ n.mu.mcastJoins[id] = joins - 1
return nil
}
@@ -1006,12 +1033,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// Are any packet sockets listening for this network protocol?
n.mu.RLock()
- packetEPs := n.packetEPs[protocol]
+ packetEPs := n.mu.packetEPs[protocol]
// Check whether there are packet sockets listening for every protocol.
// If we received a packet with protocol EthernetProtocolAll, then the
// previous for loop will have handled it.
if protocol != header.EthernetProtocolAll {
- packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...)
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
}
n.mu.RUnlock()
for _, ep := range packetEPs {
@@ -1060,8 +1087,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// Found a NIC.
n := r.ref.nic
n.mu.RLock()
- ref, ok := n.endpoints[NetworkEndpointID{dst}]
- ok = ok && ref.isValidForOutgoing() && ref.tryIncRef()
+ ref, ok := n.mu.endpoints[NetworkEndpointID{dst}]
+ ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef()
n.mu.RUnlock()
if ok {
r.RemoteAddress = src
@@ -1181,7 +1208,10 @@ func (n *NIC) Stack() *Stack {
// false. It will only return true if the address is associated with the NIC
// AND it is tentative.
func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
- ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
if !ok {
return false
}
@@ -1197,7 +1227,7 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
if !ok {
return tcpip.ErrBadAddress
}
@@ -1217,7 +1247,7 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) {
c.validate()
n.mu.Lock()
- n.ndp.configs = c
+ n.mu.ndp.configs = c
n.mu.Unlock()
}
@@ -1226,7 +1256,7 @@ func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
n.mu.Lock()
defer n.mu.Unlock()
- n.ndp.handleRA(ip, ra)
+ n.mu.ndp.handleRA(ip, ra)
}
type networkEndpointKind int32
@@ -1268,11 +1298,11 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
n.mu.Lock()
defer n.mu.Unlock()
- eps, ok := n.packetEPs[netProto]
+ eps, ok := n.mu.packetEPs[netProto]
if !ok {
return tcpip.ErrNotSupported
}
- n.packetEPs[netProto] = append(eps, ep)
+ n.mu.packetEPs[netProto] = append(eps, ep)
return nil
}
@@ -1281,14 +1311,14 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
n.mu.Lock()
defer n.mu.Unlock()
- eps, ok := n.packetEPs[netProto]
+ eps, ok := n.mu.packetEPs[netProto]
if !ok {
return
}
for i, epOther := range eps {
if epOther == ep {
- n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
+ n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
return
}
}
@@ -1346,14 +1376,19 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
// packet. It requires the endpoint to not be marked expired (i.e., its address
// has been removed), or the NIC to be in spoofing mode.
func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
- return r.getKind() != permanentExpired || r.nic.spoofing
+ r.nic.mu.RLock()
+ defer r.nic.mu.RUnlock()
+
+ return r.isValidForOutgoingRLocked()
}
-// isValidForIncoming returns true if the endpoint can accept an incoming
-// packet. It requires the endpoint to not be marked expired (i.e., its address
-// has been removed), or the NIC to be in promiscuous mode.
-func (r *referencedNetworkEndpoint) isValidForIncoming() bool {
- return r.getKind() != permanentExpired || r.nic.promiscuous
+// isValidForOutgoingRLocked returns true if the endpoint can be used to send
+// out a packet. It requires the endpoint to not be marked expired (i.e., its
+// address has been removed), or the NIC to be in spoofing mode.
+//
+// r's NIC must be read locked.
+func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
+ return r.getKind() != permanentExpired || r.nic.mu.spoofing
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index dad288642..834fe9487 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -1880,9 +1880,7 @@ func TestNICForwarding(t *testing.T) {
Data: buf.ToVectorisedView(),
})
- select {
- case <-ep2.C:
- default:
+ if _, ok := ep2.Read(); !ok {
t.Fatal("Packet not forwarded")
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index f50604a8a..869c69a6d 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -623,10 +623,8 @@ func TestTransportForwarding(t *testing.T) {
t.Fatalf("Write failed: %v", err)
}
- var p channel.PacketInfo
- select {
- case p = <-ep2.C:
- default:
+ p, ok := ep2.Read()
+ if !ok {
t.Fatal("Response packet not forwarded")
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 59c9b3fb0..0fa141d58 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -626,6 +626,12 @@ type TCPLingerTimeoutOption time.Duration
// before being marked closed.
type TCPTimeWaitTimeoutOption time.Duration
+// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a
+// accept to return a completed connection only when there is data to be
+// read. This usually means the listening socket will drop the final ACK
+// for a handshake till the specified timeout until a segment with data arrives.
+type TCPDeferAcceptOption time.Duration
+
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 3aa23d529..ac18ec5b1 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -23,7 +23,6 @@ go_library(
"icmp_packet_list.go",
"protocol.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/icmp",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
index 4858d150c..d22de6b26 100644
--- a/pkg/tcpip/transport/packet/BUILD
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -22,7 +22,6 @@ go_library(
"endpoint_state.go",
"packet_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 2f2131ff7..c9baf4600 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -23,7 +23,6 @@ go_library(
"protocol.go",
"raw_packet_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 0e3ab05ad..272e8f570 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -55,10 +54,10 @@ go_library(
"tcp_segment_list.go",
"timer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/sync",
@@ -92,6 +91,7 @@ go_test(
tags = ["flaky"],
deps = [
":tcp",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d469758eb..6101f2945 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -222,13 +222,13 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
netProto = s.route.NetProto
}
- n := newEndpoint(l.stack, netProto, nil)
+ n := newEndpoint(l.stack, netProto, queue)
n.v6only = l.v6only
n.ID = s.id
n.boundNICID = s.route.NICID()
@@ -273,16 +273,17 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// createEndpoint creates a new endpoint in connected state and then performs
// the TCP 3-way handshake.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep, err := l.createConnectingEndpoint(s, isn, irs, opts)
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
if err != nil {
return nil, err
}
// listenEP is nil when listenContext is used by tcp.Forwarder.
+ deferAccept := time.Duration(0)
if l.listenEP != nil {
l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
@@ -290,13 +291,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
+ deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
// Perform the 3-way handshake.
- h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
-
- h.resetToSynRcvd(isn, irs, opts)
+ h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.Close()
if l.listenEP != nil {
@@ -377,16 +377,14 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
defer e.decSynRcvdCount()
defer s.decRef()
- n, err := ctx.createEndpointAndPerformHandshake(s, opts)
+ n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{})
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
return
}
ctx.removePendingEndpoint(n)
- // Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
+ n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
e.deliverAccepted(n)
@@ -546,7 +544,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
+ n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{})
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
@@ -576,8 +574,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// space available in the backlog.
// Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
+ n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 4e3c5419c..9ff7ac261 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -86,6 +86,19 @@ type handshake struct {
// rcvWndScale is the receive window scale, as defined in RFC 1323.
rcvWndScale int
+
+ // startTime is the time at which the first SYN/SYN-ACK was sent.
+ startTime time.Time
+
+ // deferAccept if non-zero will drop the final ACK for a passive
+ // handshake till an ACK segment with data is received or the timeout is
+ // hit.
+ deferAccept time.Duration
+
+ // acked is true if the the final ACK for a 3-way handshake has
+ // been received. This is required to stop retransmitting the
+ // original SYN-ACK when deferAccept is enabled.
+ acked bool
}
func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
@@ -112,6 +125,12 @@ func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
return h
}
+func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake {
+ h := newHandshake(ep, rcvWnd)
+ h.resetToSynRcvd(isn, irs, opts, deferAccept)
+ return h
+}
+
// FindWndScale determines the window scale to use for the given maximum window
// size.
func FindWndScale(wnd seqnum.Size) int {
@@ -181,7 +200,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 {
// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
// state.
-func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) {
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) {
h.active = false
h.state = handshakeSynRcvd
h.flags = header.TCPFlagSyn | header.TCPFlagAck
@@ -189,6 +208,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.ackNum = irs + 1
h.mss = opts.MSS
h.sndWndScale = opts.WS
+ h.deferAccept = deferAccept
h.ep.mu.Lock()
h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
@@ -352,6 +372,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
// We have previously received (and acknowledged) the peer's SYN. If the
// peer acknowledges our SYN, the handshake is completed.
if s.flagIsSet(header.TCPFlagAck) {
+ // If deferAccept is not zero and this is a bare ACK and the
+ // timeout is not hit then drop the ACK.
+ if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept {
+ h.acked = true
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
// If the timestamp option is negotiated and the segment does
// not carry a timestamp option then the segment must be dropped
// as per https://tools.ietf.org/html/rfc7323#section-3.2.
@@ -365,10 +393,16 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
h.state = handshakeCompleted
+
h.ep.mu.Lock()
h.ep.transitionToStateEstablishedLocked(h)
+ // If the segment has data then requeue it for the receiver
+ // to process it again once main loop is started.
+ if s.data.Size() > 0 {
+ s.incRef()
+ h.ep.enqueueSegment(s)
+ }
h.ep.mu.Unlock()
-
return nil
}
@@ -471,6 +505,7 @@ func (h *handshake) execute() *tcpip.Error {
}
}
+ h.startTime = time.Now()
// Initialize the resend timer.
resendWaker := sleep.Waker{}
timeOut := time.Duration(time.Second)
@@ -524,11 +559,21 @@ func (h *handshake) execute() *tcpip.Error {
switch index, _ := s.Fetch(true); index {
case wakerForResend:
timeOut *= 2
- if timeOut > 60*time.Second {
+ if timeOut > MaxRTO {
return tcpip.ErrTimeout
}
rt.Reset(timeOut)
- h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ // Resend the SYN/SYN-ACK only if the following conditions hold.
+ // - It's an active handshake (deferAccept does not apply)
+ // - It's a passive handshake and we have not yet got the final-ACK.
+ // - It's a passive handshake and we got an ACK but deferAccept is
+ // enabled and we are now past the deferAccept duration.
+ // The last is required to provide a way for the peer to complete
+ // the connection with another ACK or data (as ACKs are never
+ // retransmitted on their own).
+ if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
+ h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ }
case wakerForNotification:
n := h.ep.fetchNotifications()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 13718ff55..b5a8e15ee 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -498,6 +498,13 @@ type endpoint struct {
// without any data being acked.
userTimeout time.Duration
+ // deferAccept if non-zero specifies a user specified time during
+ // which the final ACK of a handshake will be dropped provided the
+ // ACK is a bare ACK and carries no data. If the timeout is crossed then
+ // the bare ACK is accepted and the connection is delivered to the
+ // listener.
+ deferAccept time.Duration
+
// pendingAccepted is a synchronization primitive used to track number
// of connections that are queued up to be delivered to the accepted
// channel. We use this to ensure that all goroutines blocked on writing
@@ -1574,6 +1581,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.TCPDeferAcceptOption:
+ e.mu.Lock()
+ if time.Duration(v) > MaxRTO {
+ v = tcpip.TCPDeferAcceptOption(MaxRTO)
+ }
+ e.deferAccept = time.Duration(v)
+ e.mu.Unlock()
+ return nil
+
default:
return nil
}
@@ -1798,6 +1814,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case *tcpip.TCPDeferAcceptOption:
+ e.mu.Lock()
+ *o = tcpip.TCPDeferAcceptOption(e.deferAccept)
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -2025,8 +2047,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// work mutex is available.
if e.workMu.TryLock() {
e.mu.Lock()
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
- e.notifyProtocolGoroutine(notifyTickleWorker)
+ // We need to double check here to make
+ // sure worker has not transitioned the
+ // endpoint out of a connected state
+ // before trying to send a reset.
+ if e.EndpointState().connected() {
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
e.mu.Unlock()
e.workMu.Unlock()
} else {
@@ -2149,9 +2177,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
-func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+func (e *endpoint) startAcceptedLoop() {
e.mu.Lock()
- e.waiterQueue = waiterQueue
e.workerRunning = true
e.mu.Unlock()
wakerInitDone := make(chan struct{})
@@ -2177,7 +2204,6 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
default:
return nil, nil, tcpip.ErrWouldBlock
}
-
return n, n.waiterQueue, nil
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 7eb613be5..c9ee5bf06 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -157,13 +157,13 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
TSVal: r.synOptions.TSVal,
TSEcr: r.synOptions.TSEcr,
SACKPermitted: r.synOptions.SACKPermitted,
- })
+ }, queue)
if err != nil {
return nil, err
}
// Start the protocol goroutine.
- ep.startAcceptedLoop(queue)
+ ep.startAcceptedLoop()
return ep, nil
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index df2fb1071..2c1505067 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -6787,3 +6788,183 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
),
)
}
+
+func TestTCPDeferAccept(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ const tcpDeferAccept = 1 * time.Second
+ if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send data. This should result in an acceptable endpoint.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+
+ // Give a bit of time for the socket to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+ aep, _, err := c.EP.Accept()
+ if err != nil {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ }
+
+ aep.Close()
+ // Closing aep without reading the data should trigger a RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestTCPDeferAcceptTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ const tcpDeferAccept = 1 * time.Second
+ if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Sleep for a little of the tcpDeferAccept timeout.
+ time.Sleep(tcpDeferAccept + 100*time.Millisecond)
+
+ // On timeout expiry we should get a SYN-ACK retransmission.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+
+ // Send data. This should result in an acceptable endpoint.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+
+ // Give sometime for the endpoint to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+ aep, _, err := c.EP.Accept()
+ if err != nil {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ }
+
+ aep.Close()
+ // Closing aep without reading the data should trigger a RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestResetDuringClose(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ iss := seqnum.Value(789)
+ c.CreateConnected(iss, 30000, -1 /* epRecvBuf */)
+ // Send some data to make sure there is some unread
+ // data to trigger a reset on c.Close.
+ irs := c.IRS
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss.Add(1),
+ AckNum: irs.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(irs.Add(1))),
+ checker.AckNum(uint32(iss.Add(5)))))
+
+ // Close in a separate goroutine so that we can trigger
+ // a race with the RST we send below. This should not
+ // panic due to the route being released depeding on
+ // whether Close() sends an active RST or the RST sent
+ // below is processed by the worker first.
+ var wg sync.WaitGroup
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(5),
+ AckNum: c.IRS.Add(5),
+ RcvWnd: 30000,
+ Flags: header.TCPFlagRst,
+ })
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ c.EP.Close()
+ }()
+
+ wg.Wait()
+}
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
index b33ec2087..ce6a2c31d 100644
--- a/pkg/tcpip/transport/tcp/testing/context/BUILD
+++ b/pkg/tcpip/transport/tcp/testing/context/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,7 +6,6 @@ go_library(
name = "context",
testonly = 1,
srcs = ["context.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context",
visibility = [
"//visibility:public",
],
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 822907998..730ac4292 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -18,6 +18,7 @@ package context
import (
"bytes"
+ "context"
"testing"
"time"
@@ -215,11 +216,9 @@ func (c *Context) Stack() *stack.Stack {
func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
c.t.Helper()
- select {
- case <-c.linkEP.C:
+ ctx, _ := context.WithTimeout(context.Background(), wait)
+ if _, ok := c.linkEP.ReadContext(ctx); ok {
c.t.Fatal(errMsg)
-
- case <-time.After(wait):
}
}
@@ -234,27 +233,27 @@ func (c *Context) CheckNoPacket(errMsg string) {
// 2 seconds.
func (c *Context) GetPacket() []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
- if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
- c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
- }
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
- case <-time.After(2 * time.Second):
- c.t.Fatalf("Packet wasn't written out")
+ if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
+ c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
}
- return nil
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
}
// GetPacketNonBlocking reads a packet from the link layer endpoint
@@ -263,20 +262,21 @@ func (c *Context) GetPacket() []byte {
// nil immediately.
func (c *Context) GetPacketNonBlocking() []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
-
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
- default:
+ p, ok := c.linkEP.Read()
+ if !ok {
return nil
}
+
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
}
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
@@ -484,23 +484,23 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
func (c *Context) GetV6Packet() []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv6.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
- }
- b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
- copy(b, p.Pkt.Header.View())
- copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
-
- checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
- return b
- case <-time.After(2 * time.Second):
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
+
+ if p.Proto != ipv6.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
}
+ b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
+ copy(b, p.Pkt.Header.View())
+ copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
- return nil
+ checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
+ return b
}
// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 43fcc27f0..3ad6994a7 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "tcpconntrack",
srcs = ["tcp_conntrack.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 57ff123e3..adc908e24 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -25,7 +24,6 @@ go_library(
"protocol.go",
"udp_packet_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/udp",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index c6927cfe3..f0ff3fe71 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,6 +16,7 @@ package udp_test
import (
"bytes"
+ "context"
"fmt"
"math/rand"
"testing"
@@ -357,30 +358,29 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != flow.netProto() {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
- }
-
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
-
- h := flow.header4Tuple(outgoing)
- checkers := append(
- checkers,
- checker.SrcAddr(h.srcAddr.Addr),
- checker.DstAddr(h.dstAddr.Addr),
- checker.UDP(checker.DstPort(h.dstAddr.Port)),
- )
- flow.checkerFn()(c.t, b, checkers...)
- return b
-
- case <-time.After(2 * time.Second):
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
c.t.Fatalf("Packet wasn't written out")
+ return nil
}
- return nil
+ if p.Proto != flow.netProto() {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
+ }
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+
+ h := flow.header4Tuple(outgoing)
+ checkers = append(
+ checkers,
+ checker.SrcAddr(h.srcAddr.Addr),
+ checker.DstAddr(h.dstAddr.Addr),
+ checker.UDP(checker.DstPort(h.dstAddr.Port)),
+ )
+ flow.checkerFn()(c.t, b, checkers...)
+ return b
}
// injectPacket creates a packet of the given flow and with the given payload,
@@ -1541,48 +1541,50 @@ func TestV4UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- select {
- case p := <-c.linkEP.C:
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
- case <-time.After(1 * time.Second):
- return
}
+ return
}
- select {
- case p := <-c.linkEP.C:
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
- if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
+ // ICMP required.
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ t.Fatalf("packet wasn't written out")
+ return
+ }
- hdr := header.IPv4(pkt)
- checker.IPv4(t, hdr, checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ var pkt []byte
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
+ if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
- icmpPkt := header.ICMPv4(hdr.Payload())
- payloadIPHeader := header.IPv4(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
- }
+ hdr := header.IPv4(pkt)
+ checker.IPv4(t, hdr, checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+ icmpPkt := header.ICMPv4(hdr.Payload())
+ payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ }
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("packet wasn't written out")
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
}
})
}
@@ -1615,47 +1617,49 @@ func TestV6UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- select {
- case p := <-c.linkEP.C:
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
- case <-time.After(1 * time.Second):
- return
}
+ return
}
- select {
- case p := <-c.linkEP.C:
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
- if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
+ // ICMP required.
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ t.Fatalf("packet wasn't written out")
+ return
+ }
+
+ var pkt []byte
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
+ if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
- hdr := header.IPv6(pkt)
- checker.IPv6(t, hdr, checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+ hdr := header.IPv6(pkt)
+ checker.IPv6(t, hdr, checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
- icmpPkt := header.ICMPv6(hdr.Payload())
- payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
- }
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+ icmpPkt := header.ICMPv6(hdr.Payload())
+ payloadIPHeader := header.IPv6(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ }
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("packet wasn't written out")
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
}
})
}