summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD7
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD6
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go22
-rw-r--r--pkg/tcpip/buffer/BUILD6
-rw-r--r--pkg/tcpip/checker/BUILD3
-rw-r--r--pkg/tcpip/checker/checker.go22
-rw-r--r--pkg/tcpip/hash/jenkins/BUILD6
-rw-r--r--pkg/tcpip/header/BUILD7
-rw-r--r--pkg/tcpip/header/checksum.go133
-rw-r--r--pkg/tcpip/header/checksum_test.go62
-rw-r--r--pkg/tcpip/header/ipv6.go7
-rw-r--r--pkg/tcpip/header/ndp_router_solicit.go36
-rw-r--r--pkg/tcpip/iptables/BUILD6
-rw-r--r--pkg/tcpip/iptables/iptables.go86
-rw-r--r--pkg/tcpip/iptables/targets.go9
-rw-r--r--pkg/tcpip/iptables/types.go17
-rw-r--r--pkg/tcpip/link/channel/BUILD3
-rw-r--r--pkg/tcpip/link/channel/channel.go43
-rw-r--r--pkg/tcpip/link/fdbased/BUILD6
-rw-r--r--pkg/tcpip/link/loopback/BUILD3
-rw-r--r--pkg/tcpip/link/muxed/BUILD6
-rw-r--r--pkg/tcpip/link/rawfile/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD6
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD6
-rw-r--r--pkg/tcpip/link/sharedmem/queue/BUILD6
-rw-r--r--pkg/tcpip/link/sniffer/BUILD3
-rw-r--r--pkg/tcpip/link/tun/BUILD3
-rw-r--r--pkg/tcpip/link/waitable/BUILD6
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/BUILD4
-rw-r--r--pkg/tcpip/network/arp/arp.go2
-rw-r--r--pkg/tcpip/network/arp/arp_test.go16
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD6
-rw-r--r--pkg/tcpip/network/hash/BUILD3
-rw-r--r--pkg/tcpip/network/ip_test.go21
-rw-r--r--pkg/tcpip/network/ipv4/BUILD5
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go13
-rw-r--r--pkg/tcpip/network/ipv6/BUILD6
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go9
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go2
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go2
-rw-r--r--pkg/tcpip/ports/BUILD6
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD2
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go2
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/BUILD2
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go2
-rw-r--r--pkg/tcpip/seqnum/BUILD3
-rw-r--r--pkg/tcpip/stack/BUILD8
-rw-r--r--pkg/tcpip/stack/ndp.go371
-rw-r--r--pkg/tcpip/stack/ndp_test.go418
-rw-r--r--pkg/tcpip/stack/nic.go370
-rw-r--r--pkg/tcpip/stack/registration.go2
-rw-r--r--pkg/tcpip/stack/stack.go23
-rw-r--r--pkg/tcpip/stack/stack_test.go6
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go54
-rw-r--r--pkg/tcpip/stack/transport_test.go6
-rw-r--r--pkg/tcpip/tcpip.go28
-rw-r--r--pkg/tcpip/transport/icmp/BUILD3
-rw-r--r--pkg/tcpip/transport/packet/BUILD3
-rw-r--r--pkg/tcpip/transport/raw/BUILD3
-rw-r--r--pkg/tcpip/transport/tcp/BUILD19
-rw-r--r--pkg/tcpip/transport/tcp/accept.go10
-rw-r--r--pkg/tcpip/transport/tcp/connect.go339
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go224
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go299
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go30
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go11
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go21
-rw-r--r--pkg/tcpip/transport/tcp/snd.go14
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go94
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/BUILD3
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go86
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD4
-rw-r--r--pkg/tcpip/transport/udp/BUILD4
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go40
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go311
76 files changed, 2503 insertions, 938 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index ebc8d0209..26f7ba86b 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -12,12 +11,10 @@ go_library(
"time_unsafe.go",
"timer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip",
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
"//pkg/tcpip/buffer",
- "//pkg/tcpip/iptables",
"//pkg/waiter",
],
)
@@ -26,7 +23,7 @@ go_test(
name = "tcpip_test",
size = "small",
srcs = ["tcpip_test.go"],
- embed = [":tcpip"],
+ library = ":tcpip",
)
go_test(
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index 3df7d18d3..a984f1712 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "gonet",
srcs = ["gonet.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet",
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
@@ -23,7 +21,7 @@ go_test(
name = "gonet_test",
size = "small",
srcs = ["gonet_test.go"],
- embed = [":gonet"],
+ library = ":gonet",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index a2f44b496..711969b9b 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -556,6 +556,17 @@ type PacketConn struct {
wq *waiter.Queue
}
+// NewPacketConn creates a new PacketConn.
+func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketConn {
+ c := &PacketConn{
+ stack: s,
+ ep: ep,
+ wq: wq,
+ }
+ c.deadlineTimer.init()
+ return c
+}
+
// DialUDP creates a new PacketConn.
//
// If laddr is nil, a local address is automatically chosen.
@@ -580,12 +591,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
}
}
- c := PacketConn{
- stack: s,
- ep: ep,
- wq: &wq,
- }
- c.deadlineTimer.init()
+ c := NewPacketConn(s, &wq, ep)
if raddr != nil {
if err := c.ep.Connect(*raddr); err != nil {
@@ -599,7 +605,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
}
}
- return &c, nil
+ return c, nil
}
func (c *PacketConn) newOpError(op string, err error) *net.OpError {
@@ -622,7 +628,7 @@ func (c *PacketConn) RemoteAddr() net.Addr {
if err != nil {
return nil
}
- return fullToTCPAddr(a)
+ return fullToUDPAddr(a)
}
// Read implements net.Conn.Read
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index d6c31bfa2..563bc78ea 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"prependable.go",
"view.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/buffer",
visibility = ["//visibility:public"],
)
@@ -17,5 +15,5 @@ go_test(
name = "buffer_test",
size = "small",
srcs = ["view_test.go"],
- embed = [":buffer"],
+ library = ":buffer",
)
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
index b6fa6fc37..ed434807f 100644
--- a/pkg/tcpip/checker/BUILD
+++ b/pkg/tcpip/checker/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,7 +6,6 @@ go_library(
name = "checker",
testonly = 1,
srcs = ["checker.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/checker",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 2f15bf1f1..885d773b0 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -33,6 +33,9 @@ type NetworkChecker func(*testing.T, []header.Network)
// TransportChecker is a function to check a property of a transport packet.
type TransportChecker func(*testing.T, header.Transport)
+// ControlMessagesChecker is a function to check a property of ancillary data.
+type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages)
+
// IPv4 checks the validity and properties of the given IPv4 packet. It is
// expected to be used in conjunction with other network checkers for specific
// properties. For example, to check the source and destination address, one
@@ -158,6 +161,19 @@ func FragmentFlags(flags uint8) NetworkChecker {
}
}
+// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
+func ReceiveTOS(want uint8) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTOS {
+ t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want)
+ }
+ if got := cm.TOS; got != want {
+ t.Fatalf("got cm.TOS = %d, want %d", got, want)
+ }
+ }
+}
+
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -754,3 +770,9 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
}
}
}
+
+// NDPRS creates a checker that checks that the packet contains a valid NDP
+// Router Solicitation message (as per the raw wire format).
+func NDPRS() NetworkChecker {
+ return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize)
+}
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
index e648efa71..ff2719291 100644
--- a/pkg/tcpip/hash/jenkins/BUILD
+++ b/pkg/tcpip/hash/jenkins/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "jenkins",
srcs = ["jenkins.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins",
visibility = ["//visibility:public"],
)
@@ -16,5 +14,5 @@ go_test(
srcs = [
"jenkins_test.go",
],
- embed = [":jenkins"],
+ library = ":jenkins",
)
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index f2061c778..9da0d71f8 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -20,10 +19,10 @@ go_library(
"ndp_neighbor_solicit.go",
"ndp_options.go",
"ndp_router_advert.go",
+ "ndp_router_solicit.go",
"tcp.go",
"udp.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/header",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
@@ -58,7 +57,7 @@ go_test(
"eth_test.go",
"ndp_test.go",
],
- embed = [":header"],
+ library = ":header",
deps = [
"//pkg/tcpip",
"@com_github_google_go-cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 9749c7f4d..14a4b2b44 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -45,12 +45,139 @@ func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
return ChecksumCombine(uint16(v), uint16(v>>16)), odd
}
+func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
+ v := initial
+
+ if odd {
+ v += uint32(buf[0])
+ buf = buf[1:]
+ }
+
+ l := len(buf)
+ odd = l&1 != 0
+ if odd {
+ l--
+ v += uint32(buf[l]) << 8
+ }
+ for (l - 64) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[64:]
+ l = l - 64
+ }
+ if (l - 32) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[32:]
+ l = l - 32
+ }
+ if (l - 16) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[16:]
+ l = l - 16
+ }
+ if (l - 8) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ buf = buf[8:]
+ l = l - 8
+ }
+ if (l - 4) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ buf = buf[4:]
+ l = l - 4
+ }
+
+ // At this point since l was even before we started unrolling
+ // there can be only two bytes left to add.
+ if l != 0 {
+ v += (uint32(buf[0]) << 8) + uint32(buf[1])
+ }
+
+ return ChecksumCombine(uint16(v), uint16(v>>16)), odd
+}
+
+// ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in
+// the given byte array. This function uses a non-optimized implementation. Its
+// only retained for reference and to use as a benchmark/test. Most code should
+// use the header.Checksum function.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumOld(buf []byte, initial uint16) uint16 {
+ s, _ := calculateChecksum(buf, false, uint32(initial))
+ return s
+}
+
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
-// given byte array.
+// given byte array. This function uses an optimized unrolled version of the
+// checksum algorithm.
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
- s, _ := calculateChecksum(buf, false, uint32(initial))
+ s, _ := unrolledCalculateChecksum(buf, false, uint32(initial))
return s
}
@@ -86,7 +213,7 @@ func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, siz
}
v = v[:l]
- sum, odd = calculateChecksum(v, odd, uint32(sum))
+ sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum))
size -= len(v)
if size == 0 {
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index 86b466c1c..309403482 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -17,6 +17,8 @@
package header_test
import (
+ "fmt"
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -107,3 +109,63 @@ func TestChecksumVVWithOffset(t *testing.T) {
})
}
}
+
+func TestChecksum(t *testing.T) {
+ var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024}
+ type testCase struct {
+ buf []byte
+ initial uint16
+ csumOrig uint16
+ csumNew uint16
+ }
+ testCases := make([]testCase, 100000)
+ // Ensure same buffer generation for test consistency.
+ rnd := rand.New(rand.NewSource(42))
+ for i := range testCases {
+ testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)])
+ testCases[i].initial = uint16(rnd.Intn(65536))
+ rnd.Read(testCases[i].buf)
+ }
+
+ for i := range testCases {
+ testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial)
+ testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial)
+ if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want {
+ t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want)
+ }
+ }
+}
+
+func BenchmarkChecksum(b *testing.B) {
+ var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536}
+
+ checkSumImpls := []struct {
+ fn func([]byte, uint16) uint16
+ name string
+ }{
+ {header.ChecksumOld, fmt.Sprintf("checksum_old")},
+ {header.Checksum, fmt.Sprintf("checksum")},
+ }
+
+ for _, csumImpl := range checkSumImpls {
+ // Ensure same buffer generation for test consistency.
+ rnd := rand.New(rand.NewSource(42))
+ for _, bufSz := range bufSizes {
+ b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) {
+ tc := struct {
+ buf []byte
+ initial uint16
+ csum uint16
+ }{
+ buf: make([]byte, bufSz),
+ initial: uint16(rnd.Intn(65536)),
+ }
+ rnd.Read(tc.buf)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ tc.csum = csumImpl.fn(tc.buf, tc.initial)
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 83425c614..70e6ce095 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -84,6 +84,13 @@ const (
// The address is ff02::1.
IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ // IPv6AllRoutersMulticastAddress is a link-local multicast group that
+ // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all routers on a link.
+ //
+ // The address is ff02::2.
+ IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
// section 5.
IPv6MinimumMTU = 1280
diff --git a/pkg/tcpip/header/ndp_router_solicit.go b/pkg/tcpip/header/ndp_router_solicit.go
new file mode 100644
index 000000000..9e67ba95d
--- /dev/null
+++ b/pkg/tcpip/header/ndp_router_solicit.go
@@ -0,0 +1,36 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain
+// the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.1 for more details.
+type NDPRouterSolicit []byte
+
+const (
+ // NDPRSMinimumSize is the minimum size of a valid NDP Router
+ // Solicitation message (body of an ICMPv6 packet).
+ NDPRSMinimumSize = 4
+
+ // ndpRSOptionsOffset is the start of the NDP options in an
+ // NDPRouterSolicit.
+ ndpRSOptionsOffset = 4
+)
+
+// Options returns an NDPOptions of the the options body.
+func (b NDPRouterSolicit) Options() NDPOptions {
+ return NDPOptions(b[ndpRSOptionsOffset:])
+}
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
index 64769c333..d1b73cfdf 100644
--- a/pkg/tcpip/iptables/BUILD
+++ b/pkg/tcpip/iptables/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -9,10 +9,10 @@ go_library(
"targets.go",
"types.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
- "//pkg/tcpip/buffer",
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
],
)
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
index 647970133..4bfb3149e 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/iptables/iptables.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,6 +16,13 @@
// tool.
package iptables
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
// Table names.
const (
TablenameNat = "nat"
@@ -127,3 +134,80 @@ func EmptyFilterTable() Table {
UserChains: map[string]int{},
}
}
+
+// Check runs pkt through the rules for hook. It returns true when the packet
+// should continue traversing the network stack and false when it should be
+// dropped.
+func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
+ // TODO(gvisor.dev/issue/170): A lot of this is uncomplicated because
+ // we're missing features. Jumps, the call stack, etc. aren't checked
+ // for yet because we're yet to support them.
+
+ // Go through each table containing the hook.
+ for _, tablename := range it.Priorities[hook] {
+ switch verdict := it.checkTable(hook, pkt, tablename); verdict {
+ // If the table returns Accept, move on to the next table.
+ case Accept:
+ continue
+ // The Drop verdict is final.
+ case Drop:
+ return false
+ case Stolen, Queue, Repeat, None, Jump, Return, Continue:
+ panic(fmt.Sprintf("Unimplemented verdict %v.", verdict))
+ default:
+ panic(fmt.Sprintf("Unknown verdict %v.", verdict))
+ }
+ }
+
+ // Every table returned Accept.
+ return true
+}
+
+func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) Verdict {
+ // Start from ruleIdx and walk the list of rules until a rule gives us
+ // a verdict.
+ table := it.Tables[tablename]
+ for ruleIdx := table.BuiltinChains[hook]; ruleIdx < len(table.Rules); ruleIdx++ {
+ switch verdict := it.checkRule(hook, pkt, table, ruleIdx); verdict {
+ // In either of these cases, this table is done with the packet.
+ case Accept, Drop:
+ return verdict
+ // Continue traversing the rules of the table.
+ case Continue:
+ continue
+ case Stolen, Queue, Repeat, None, Jump, Return:
+ panic(fmt.Sprintf("Unimplemented verdict %v.", verdict))
+ default:
+ panic(fmt.Sprintf("Unknown verdict %v.", verdict))
+ }
+ }
+
+ panic(fmt.Sprintf("Traversed past the entire list of iptables rules in table %q.", tablename))
+}
+
+// Precondition: pk.NetworkHeader is set.
+func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) Verdict {
+ rule := table.Rules[ruleIdx]
+
+ // First check whether the packet matches the IP header filter.
+ // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+ if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() {
+ return Continue
+ }
+
+ // Go through each rule matcher. If they all match, run
+ // the rule target.
+ for _, matcher := range rule.Matchers {
+ matches, hotdrop := matcher.Match(hook, pkt, "")
+ if hotdrop {
+ return Drop
+ }
+ if !matches {
+ return Continue
+ }
+ }
+
+ // All the matchers matched, so run the target.
+ verdict, _ := rule.Target.Action(pkt)
+ return verdict
+}
diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go
index b94a4c941..4dd281371 100644
--- a/pkg/tcpip/iptables/targets.go
+++ b/pkg/tcpip/iptables/targets.go
@@ -18,14 +18,14 @@ package iptables
import (
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip"
)
// UnconditionalAcceptTarget accepts all packets.
type UnconditionalAcceptTarget struct{}
// Action implements Target.Action.
-func (UnconditionalAcceptTarget) Action(packet buffer.VectorisedView) (Verdict, string) {
+func (UnconditionalAcceptTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) {
return Accept, ""
}
@@ -33,7 +33,7 @@ func (UnconditionalAcceptTarget) Action(packet buffer.VectorisedView) (Verdict,
type UnconditionalDropTarget struct{}
// Action implements Target.Action.
-func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, string) {
+func (UnconditionalDropTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) {
return Drop, ""
}
@@ -42,8 +42,7 @@ func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, st
type ErrorTarget struct{}
// Action implements Target.Action.
-func (ErrorTarget) Action(packet buffer.VectorisedView) (Verdict, string) {
+func (ErrorTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) {
log.Warningf("ErrorTarget triggered.")
return Drop, ""
-
}
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
index 540f8c0b4..50893cc55 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/iptables/types.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
package iptables
import (
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip"
)
// A Hook specifies one of the hooks built into the network stack.
@@ -153,6 +153,9 @@ func (table *Table) SetMetadata(metadata interface{}) {
// packets this rule applies to. If there are no matchers in the rule, it
// applies to any packet.
type Rule struct {
+ // Filter holds basic IP filtering fields common to every rule.
+ Filter IPHeaderFilter
+
// Matchers is the list of matchers for this rule.
Matchers []Matcher
@@ -160,12 +163,18 @@ type Rule struct {
Target Target
}
+// IPHeaderFilter holds basic IP filtering data common to every rule.
+type IPHeaderFilter struct {
+ // Protocol matches the transport protocol.
+ Protocol tcpip.TransportProtocolNumber
+}
+
// A Matcher is the interface for matching packets.
type Matcher interface {
// Match returns whether the packet matches and whether the packet
// should be "hotdropped", i.e. dropped immediately. This is usually
// used for suspicious packets.
- Match(hook Hook, packet buffer.VectorisedView, interfaceName string) (matches bool, hotdrop bool)
+ Match(hook Hook, packet tcpip.PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
// A Target is the interface for taking an action for a packet.
@@ -173,5 +182,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the name of the chain to jump to.
- Action(packet buffer.VectorisedView) (Verdict, string)
+ Action(packet tcpip.PacketBuffer) (Verdict, string)
}
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 7dbc05754..3974c464e 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "channel",
srcs = ["channel.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 70188551f..71b9da797 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -18,6 +18,8 @@
package channel
import (
+ "context"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -38,25 +40,52 @@ type Endpoint struct {
linkAddr tcpip.LinkAddress
GSO bool
- // C is where outbound packets are queued.
- C chan PacketInfo
+ // c is where outbound packets are queued.
+ c chan PacketInfo
}
// New creates a new channel endpoint.
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
return &Endpoint{
- C: make(chan PacketInfo, size),
+ c: make(chan PacketInfo, size),
mtu: mtu,
linkAddr: linkAddr,
}
}
+// Close closes e. Further packet injections will panic. Reads continue to
+// succeed until all packets are read.
+func (e *Endpoint) Close() {
+ close(e.c)
+}
+
+// Read does non-blocking read for one packet from the outbound packet queue.
+func (e *Endpoint) Read() (PacketInfo, bool) {
+ select {
+ case pkt := <-e.c:
+ return pkt, true
+ default:
+ return PacketInfo{}, false
+ }
+}
+
+// ReadContext does blocking read for one packet from the outbound packet queue.
+// It can be cancelled by ctx, and in this case, it returns false.
+func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ select {
+ case pkt := <-e.c:
+ return pkt, true
+ case <-ctx.Done():
+ return PacketInfo{}, false
+ }
+}
+
// Drain removes all outbound packets from the channel and counts them.
func (e *Endpoint) Drain() int {
c := 0
for {
select {
- case <-e.C:
+ case <-e.c:
c++
default:
return c
@@ -125,7 +154,7 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.Ne
}
select {
- case e.C <- p:
+ case e.c <- p:
default:
}
@@ -150,7 +179,7 @@ packetLoop:
}
select {
- case e.C <- p:
+ case e.c <- p:
n++
default:
break packetLoop
@@ -169,7 +198,7 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
}
select {
- case e.C <- p:
+ case e.c <- p:
default:
}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index 66cc53ed4..abe725548 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -13,7 +12,6 @@ go_library(
"mmap_unsafe.go",
"packet_dispatchers.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased",
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
@@ -30,7 +28,7 @@ go_test(
name = "fdbased_test",
size = "small",
srcs = ["endpoint_test.go"],
- embed = [":fdbased"],
+ library = ":fdbased",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD
index f35fcdff4..6bf3805b7 100644
--- a/pkg/tcpip/link/loopback/BUILD
+++ b/pkg/tcpip/link/loopback/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "loopback",
srcs = ["loopback.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index 1ac7948b6..82b441b79 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "muxed",
srcs = ["injectable.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
@@ -19,7 +17,7 @@ go_test(
name = "muxed_test",
size = "small",
srcs = ["injectable_test.go"],
- embed = [":muxed"],
+ library = ":muxed",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index d8211e93d..14b527bc2 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -12,7 +12,6 @@ go_library(
"errors.go",
"rawfile_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index 09165dd4c..13243ebbb 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -11,7 +10,6 @@ go_library(
"sharedmem_unsafe.go",
"tx.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
@@ -30,7 +28,7 @@ go_test(
srcs = [
"sharedmem_test.go",
],
- embed = [":sharedmem"],
+ library = ":sharedmem",
deps = [
"//pkg/sync",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
index a0d4ad0be..87020ec08 100644
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -11,7 +10,6 @@ go_library(
"rx.go",
"tx.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe",
visibility = ["//visibility:public"],
)
@@ -20,6 +18,6 @@ go_test(
srcs = [
"pipe_test.go",
],
- embed = [":pipe"],
+ library = ":pipe",
deps = ["//pkg/sync"],
)
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
index 8c9234d54..3ba06af73 100644
--- a/pkg/tcpip/link/sharedmem/queue/BUILD
+++ b/pkg/tcpip/link/sharedmem/queue/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"rx.go",
"tx.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
@@ -22,7 +20,7 @@ go_test(
srcs = [
"queue_test.go",
],
- embed = [":queue"],
+ library = ":queue",
deps = [
"//pkg/tcpip/link/sharedmem/pipe",
],
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
index d6ae0368a..230a8d53a 100644
--- a/pkg/tcpip/link/sniffer/BUILD
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,7 +8,6 @@ go_library(
"pcap.go",
"sniffer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index a71a493fc..e5096ea38 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -1,10 +1,9 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "tun",
srcs = ["tun_unsafe.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun",
visibility = ["//visibility:public"],
)
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 134837943..0956d2c65 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -8,7 +7,6 @@ go_library(
srcs = [
"waitable.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable",
visibility = ["//visibility:public"],
deps = [
"//pkg/gate",
@@ -23,7 +21,7 @@ go_test(
srcs = [
"waitable_test.go",
],
- embed = [":waitable"],
+ library = ":waitable",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 9d16ff8c9..6a4839fb8 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index e7617229b..eddf7b725 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "arp",
srcs = ["arp.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 42cacb8a6..1ceaebfbd 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -137,7 +137,7 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
if addrWithPrefix.Address != ProtocolAddress {
return nil, tcpip.ErrBadLocalAddress
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 8e6048a21..03cf03b6d 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -15,6 +15,7 @@
package arp_test
import (
+ "context"
"strconv"
"testing"
"time"
@@ -83,7 +84,7 @@ func newTestContext(t *testing.T) *testContext {
}
func (c *testContext) cleanup() {
- close(c.linkEP.C)
+ c.linkEP.Close()
}
func TestDirectRequest(t *testing.T) {
@@ -110,7 +111,7 @@ func TestDirectRequest(t *testing.T) {
for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
inject(address)
- pi := <-c.linkEP.C
+ pi, _ := c.linkEP.ReadContext(context.Background())
if pi.Proto != arp.ProtocolNumber {
t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
}
@@ -134,12 +135,11 @@ func TestDirectRequest(t *testing.T) {
}
inject(stackAddrBad)
- select {
- case pkt := <-c.linkEP.C:
+ // Sleep tests are gross, but this will only potentially flake
+ // if there's a bug. If there is no bug this will reliably
+ // succeed.
+ ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ if pkt, ok := c.linkEP.ReadContext(ctx); ok {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
- case <-time.After(100 * time.Millisecond):
- // Sleep tests are gross, but this will only potentially flake
- // if there's a bug. If there is no bug this will reliably
- // succeed.
}
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index ed16076fd..d1c728ccf 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -24,7 +23,6 @@ go_library(
"reassembler.go",
"reassembler_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation",
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
@@ -42,6 +40,6 @@ go_test(
"fragmentation_test.go",
"reassembler_test.go",
],
- embed = [":fragmentation"],
+ library = ":fragmentation",
deps = ["//pkg/tcpip/buffer"],
)
diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD
index e6db5c0b0..872165866 100644
--- a/pkg/tcpip/network/hash/BUILD
+++ b/pkg/tcpip/network/hash/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "hash",
srcs = ["hash.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/hash",
visibility = ["//visibility:public"],
deps = [
"//pkg/rand",
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index f1bc33adf..f4d78f8c6 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -212,10 +212,17 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
}
+func buildDummyStack() *stack.Stack {
+ return stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
+}
+
func TestIPv4Send(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack())
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -250,7 +257,7 @@ func TestIPv4Send(t *testing.T) {
func TestIPv4Receive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -318,7 +325,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -385,7 +392,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
func TestIPv4FragmentationReceive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -456,7 +463,7 @@ func TestIPv4FragmentationReceive(t *testing.T) {
func TestIPv6Send(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack())
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -491,7 +498,7 @@ func TestIPv6Send(t *testing.T) {
func TestIPv6Receive(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
@@ -568,7 +575,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil)
+ ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index aeddfcdd4..0fef2b1f1 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,12 +8,12 @@ go_library(
"icmp.go",
"ipv4.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 4ee3d5b45..85512f9b2 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -54,10 +55,11 @@ type endpoint struct {
dispatcher stack.TransportDispatcher
fragmentation *fragmentation.Fragmentation
protocol *protocol
+ stack *stack.Stack
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
e := &endpoint{
nicID: nicID,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
@@ -66,6 +68,7 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi
dispatcher: dispatcher,
fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
protocol: p,
+ stack: st,
}
return e, nil
@@ -350,6 +353,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
}
pkt.NetworkHeader = headerView[:h.HeaderLength()]
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and will not be forwarded.
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(iptables.Input, pkt); !ok {
+ // iptables is telling us to drop the packet.
+ return
+ }
+
hlen := int(h.HeaderLength())
tlen := int(h.TotalLength())
pkt.Data.TrimFront(hlen)
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index e4e273460..fb11874c6 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"icmp.go",
"ipv6.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
@@ -27,7 +25,7 @@ go_test(
"ipv6_test.go",
"ndp_test.go",
],
- embed = [":ipv6"],
+ library = ":ipv6",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 335f634d5..7a6820643 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "context"
"reflect"
"strings"
"testing"
@@ -109,7 +110,7 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil)
+ ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
if err != nil {
t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
}
@@ -264,8 +265,8 @@ func newTestContext(t *testing.T) *testContext {
}
func (c *testContext) cleanup() {
- close(c.linkEP0.C)
- close(c.linkEP1.C)
+ c.linkEP0.Close()
+ c.linkEP1.Close()
}
type routeArgs struct {
@@ -276,7 +277,7 @@ type routeArgs struct {
func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
t.Helper()
- pi := <-args.src.C
+ pi, _ := args.src.ReadContext(context.Background())
{
views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 58c3c79b9..180a480fd 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -221,7 +221,7 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
return &endpoint{
nicID: nicID,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 0dbce14a0..fe895b376 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -62,7 +62,7 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil)
+ ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
if err != nil {
t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index a6ef3bdcc..2bad05a2e 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -1,12 +1,10 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "ports",
srcs = ["ports.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/ports",
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
@@ -17,7 +15,7 @@ go_library(
go_test(
name = "ports_test",
srcs = ["ports_test.go"],
- embed = [":ports"],
+ library = ":ports",
deps = [
"//pkg/tcpip",
],
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
index d7496fde6..cf0a5fefe 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 2239c1e66..0ab089208 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -164,7 +164,7 @@ func main() {
// Create TCP endpoint.
var wq waiter.Queue
ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
+ if e != nil {
log.Fatal(e)
}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD
index 875561566..43264b76d 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index bca73cbb1..9e37cab18 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -168,7 +168,7 @@ func main() {
// Create TCP endpoint, bind it, then start listening.
var wq waiter.Queue
ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
- if err != nil {
+ if e != nil {
log.Fatal(e)
}
diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD
index b31ddba2f..45f503845 100644
--- a/pkg/tcpip/seqnum/BUILD
+++ b/pkg/tcpip/seqnum/BUILD
@@ -1,10 +1,9 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "seqnum",
srcs = ["seqnum.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum",
visibility = ["//visibility:public"],
)
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 705e984c1..f5b750046 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -30,7 +29,6 @@ go_library(
"stack_global_state.go",
"transport_demuxer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/stack",
visibility = ["//visibility:public"],
deps = [
"//pkg/ilist",
@@ -51,7 +49,7 @@ go_library(
go_test(
name = "stack_x_test",
- size = "small",
+ size = "medium",
srcs = [
"ndp_test.go",
"stack_test.go",
@@ -81,7 +79,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = ["linkaddrcache_test.go"],
- embed = [":stack"],
+ library = ":stack",
deps = [
"//pkg/sleep",
"//pkg/sync",
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index a9dd322db..d983ac390 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -15,8 +15,8 @@
package stack
import (
- "fmt"
"log"
+ "math/rand"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -38,24 +38,36 @@ const (
// Default = 1s (from RFC 4861 section 10).
defaultRetransmitTimer = time.Second
+ // defaultMaxRtrSolicitations is the default number of Router
+ // Solicitation messages to send when a NIC becomes enabled.
+ //
+ // Default = 3 (from RFC 4861 section 10).
+ defaultMaxRtrSolicitations = 3
+
+ // defaultRtrSolicitationInterval is the default amount of time between
+ // sending Router Solicitation messages.
+ //
+ // Default = 4s (from 4861 section 10).
+ defaultRtrSolicitationInterval = 4 * time.Second
+
+ // defaultMaxRtrSolicitationDelay is the default maximum amount of time
+ // to wait before sending the first Router Solicitation message.
+ //
+ // Default = 1s (from 4861 section 10).
+ defaultMaxRtrSolicitationDelay = time.Second
+
// defaultHandleRAs is the default configuration for whether or not to
// handle incoming Router Advertisements as a host.
- //
- // Default = true.
defaultHandleRAs = true
// defaultDiscoverDefaultRouters is the default configuration for
// whether or not to discover default routers from incoming Router
// Advertisements, as a host.
- //
- // Default = true.
defaultDiscoverDefaultRouters = true
// defaultDiscoverOnLinkPrefixes is the default configuration for
// whether or not to discover on-link prefixes from incoming Router
// Advertisements' Prefix Information option, as a host.
- //
- // Default = true.
defaultDiscoverOnLinkPrefixes = true
// defaultAutoGenGlobalAddresses is the default configuration for
@@ -74,26 +86,31 @@ const (
// value of 0 means unspecified, so the smallest valid value is 1.
// Note, the unit of the RetransmitTimer field in the Router
// Advertisement is milliseconds.
- //
- // Min = 1ms.
minimumRetransmitTimer = time.Millisecond
+ // minimumRtrSolicitationInterval is the minimum amount of time to wait
+ // between sending Router Solicitation messages. This limit is imposed
+ // to make sure that Router Solicitation messages are not sent all at
+ // once, defeating the purpose of sending the initial few messages.
+ minimumRtrSolicitationInterval = 500 * time.Millisecond
+
+ // minimumMaxRtrSolicitationDelay is the minimum amount of time to wait
+ // before sending the first Router Solicitation message. It is 0 because
+ // we cannot have a negative delay.
+ minimumMaxRtrSolicitationDelay = 0
+
// MaxDiscoveredDefaultRouters is the maximum number of discovered
// default routers. The stack should stop discovering new routers after
// discovering MaxDiscoveredDefaultRouters routers.
//
// This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
// SHOULD be more.
- //
- // Max = 10.
MaxDiscoveredDefaultRouters = 10
// MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
// on-link prefixes. The stack should stop discovering new on-link
// prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link
// prefixes.
- //
- // Max = 10.
MaxDiscoveredOnLinkPrefixes = 10
// validPrefixLenForAutoGen is the expected prefix length that an
@@ -245,9 +262,24 @@ type NDPConfigurations struct {
// The amount of time to wait between sending Neighbor solicitation
// messages.
//
- // Must be greater than 0.5s.
+ // Must be greater than or equal to 1ms.
RetransmitTimer time.Duration
+ // The number of Router Solicitation messages to send when the NIC
+ // becomes enabled.
+ MaxRtrSolicitations uint8
+
+ // The amount of time between transmitting Router Solicitation messages.
+ //
+ // Must be greater than or equal to 0.5s.
+ RtrSolicitationInterval time.Duration
+
+ // The maximum amount of time before transmitting the first Router
+ // Solicitation message.
+ //
+ // Must be greater than or equal to 0s.
+ MaxRtrSolicitationDelay time.Duration
+
// HandleRAs determines whether or not Router Advertisements will be
// processed.
HandleRAs bool
@@ -278,12 +310,15 @@ type NDPConfigurations struct {
// default values.
func DefaultNDPConfigurations() NDPConfigurations {
return NDPConfigurations{
- DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
- RetransmitTimer: defaultRetransmitTimer,
- HandleRAs: defaultHandleRAs,
- DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
- DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
- AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
+ DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
+ RetransmitTimer: defaultRetransmitTimer,
+ MaxRtrSolicitations: defaultMaxRtrSolicitations,
+ RtrSolicitationInterval: defaultRtrSolicitationInterval,
+ MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
+ HandleRAs: defaultHandleRAs,
+ DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
+ AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
}
}
@@ -292,10 +327,24 @@ func DefaultNDPConfigurations() NDPConfigurations {
//
// If RetransmitTimer is less than minimumRetransmitTimer, then a value of
// defaultRetransmitTimer will be used.
+//
+// If RtrSolicitationInterval is less than minimumRtrSolicitationInterval, then
+// a value of defaultRtrSolicitationInterval will be used.
+//
+// If MaxRtrSolicitationDelay is less than minimumMaxRtrSolicitationDelay, then
+// a value of defaultMaxRtrSolicitationDelay will be used.
func (c *NDPConfigurations) validate() {
if c.RetransmitTimer < minimumRetransmitTimer {
c.RetransmitTimer = defaultRetransmitTimer
}
+
+ if c.RtrSolicitationInterval < minimumRtrSolicitationInterval {
+ c.RtrSolicitationInterval = defaultRtrSolicitationInterval
+ }
+
+ if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay {
+ c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay
+ }
}
// ndpState is the per-interface NDP state.
@@ -316,6 +365,10 @@ type ndpState struct {
// Information option.
onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState
+ // The timer used to send the next router solicitation message.
+ // If routers are being solicited, rtrSolicitTimer MUST NOT be nil.
+ rtrSolicitTimer *time.Timer
+
// The addresses generated by SLAAC.
autoGenAddresses map[tcpip.Address]autoGenAddressState
@@ -375,87 +428,93 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
return tcpip.ErrAddressFamilyNotSupported
}
- // Should not attempt to perform DAD on an address that is currently in
- // the DAD process.
+ if ref.getKind() != permanentTentative {
+ // The endpoint should be marked as tentative since we are starting DAD.
+ log.Fatalf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())
+ }
+
+ // Should not attempt to perform DAD on an address that is currently in the
+ // DAD process.
if _, ok := ndp.dad[addr]; ok {
- // Should never happen because we should only ever call this
- // function for newly created addresses. If we attemped to
- // "add" an address that already existed, we would returned an
- // error since we attempted to add a duplicate address, or its
- // reference count would have been increased without doing the
- // work that would have been done for an address that was brand
- // new. See NIC.addPermanentAddressLocked.
- panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
+ // Should never happen because we should only ever call this function for
+ // newly created addresses. If we attemped to "add" an address that already
+ // existed, we would get an error since we attempted to add a duplicate
+ // address, or its reference count would have been increased without doing
+ // the work that would have been done for an address that was brand new.
+ // See NIC.addAddressLocked.
+ log.Fatalf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())
}
remaining := ndp.configs.DupAddrDetectTransmits
-
- {
- done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref)
- if err != nil {
- return err
- }
- if done {
- return nil
- }
+ if remaining == 0 {
+ ref.setKind(permanent)
+ return nil
}
- remaining--
-
var done bool
var timer *time.Timer
- timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() {
- var d bool
- var err *tcpip.Error
-
- // doDadIteration does a single iteration of the DAD loop.
- //
- // Returns true if the integrator needs to be informed of DAD
- // completing.
- doDadIteration := func() bool {
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
-
- if done {
- // If we reach this point, it means that the DAD
- // timer fired after another goroutine already
- // obtained the NIC lock and stopped DAD before
- // this function obtained the NIC lock. Simply
- // return here and do nothing further.
- return false
- }
+ // We initially start a timer to fire immediately because some of the DAD work
+ // cannot be done while holding the NIC's lock. This is effectively the same
+ // as starting a goroutine but we use a timer that fires immediately so we can
+ // reset it for the next DAD iteration.
+ timer = time.AfterFunc(0, func() {
+ ndp.nic.mu.RLock()
+ if done {
+ // If we reach this point, it means that the DAD timer fired after
+ // another goroutine already obtained the NIC lock and stopped DAD
+ // before this function obtained the NIC lock. Simply return here and do
+ // nothing further.
+ ndp.nic.mu.RUnlock()
+ return
+ }
- ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}]
- if !ok {
- // This should never happen.
- // We should have an endpoint for addr since we
- // are still performing DAD on it. If the
- // endpoint does not exist, but we are doing DAD
- // on it, then we started DAD at some point, but
- // forgot to stop it when the endpoint was
- // deleted.
- panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID()))
- }
+ if ref.getKind() != permanentTentative {
+ // The endpoint should still be marked as tentative since we are still
+ // performing DAD on it.
+ log.Fatalf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID())
+ }
- d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref)
- if err != nil || d {
- delete(ndp.dad, addr)
+ dadDone := remaining == 0
+ ndp.nic.mu.RUnlock()
- if err != nil {
- log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err)
- }
+ var err *tcpip.Error
+ if !dadDone {
+ err = ndp.sendDADPacket(addr)
+ }
- // Let the integrator know DAD has completed.
- return true
- }
+ ndp.nic.mu.Lock()
+ if done {
+ // If we reach this point, it means that DAD was stopped after we released
+ // the NIC's read lock and before we obtained the write lock.
+ ndp.nic.mu.Unlock()
+ return
+ }
+ if dadDone {
+ // DAD has resolved.
+ ref.setKind(permanent)
+ } else if err == nil {
+ // DAD is not done and we had no errors when sending the last NDP NS,
+ // schedule the next DAD timer.
remaining--
timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
- return false
+
+ ndp.nic.mu.Unlock()
+ return
}
- if doDadIteration() && ndp.nic.stack.ndpDisp != nil {
- ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err)
+ // At this point we know that either DAD is done or we hit an error sending
+ // the last NDP NS. Either way, clean up addr's DAD state and let the
+ // integrator know DAD has completed.
+ delete(ndp.dad, addr)
+ ndp.nic.mu.Unlock()
+
+ if err != nil {
+ log.Printf("ndpdad: error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err)
+ }
+
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err)
}
})
@@ -467,44 +526,17 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
return nil
}
-// doDuplicateAddressDetection is called on every iteration of the timer, and
-// when DAD starts.
-//
-// It handles resolving the address (if there are no more NS to send), or
-// sending the next NS if there are more NS to send.
+// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns
+// addr.
//
-// This function must only be called by IPv6 addresses that are currently
-// tentative.
-//
-// The NIC that ndp belongs to (n) MUST be locked.
-//
-// Returns true if DAD has resolved; false if DAD is still ongoing.
-func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) {
- if ref.getKind() != permanentTentative {
- // The endpoint should still be marked as tentative
- // since we are still performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
- }
-
- if remaining == 0 {
- // DAD has resolved.
- ref.setKind(permanent)
- return true, nil
- }
-
- // Send a new NS.
+// addr must be a tentative IPv6 address on ndp's NIC.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}]
- if !ok {
- // This should never happen as if we have the
- // address, we should have the solicited-node
- // address.
- panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr))
- }
- // Use the unspecified address as the source address when performing
- // DAD.
- r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false)
+ // Use the unspecified address as the source address when performing DAD.
+ ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
+ r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ defer r.Release()
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
@@ -514,15 +546,19 @@ func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining u
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}, tcpip.PacketBuffer{
- Header: hdr,
- }); err != nil {
+ if err := r.WritePacket(nil,
+ NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ TOS: DefaultTOS,
+ }, tcpip.PacketBuffer{Header: hdr},
+ ); err != nil {
sent.Dropped.Increment()
- return false, err
+ return err
}
sent.NeighborSolicit.Increment()
- return false, nil
+ return nil
}
// stopDuplicateAddressDetection ends a running Duplicate Address Detection
@@ -1108,7 +1144,7 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo
//
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupHostOnlyState() {
- for addr, _ := range ndp.autoGenAddresses {
+ for addr := range ndp.autoGenAddresses {
ndp.invalidateAutoGenAddress(addr)
}
@@ -1116,7 +1152,7 @@ func (ndp *ndpState) cleanupHostOnlyState() {
log.Fatalf("ndp: still have auto-generated addresses after cleaning up, found = %d", got)
}
- for prefix, _ := range ndp.onLinkPrefixes {
+ for prefix := range ndp.onLinkPrefixes {
ndp.invalidateOnLinkPrefix(prefix)
}
@@ -1124,7 +1160,7 @@ func (ndp *ndpState) cleanupHostOnlyState() {
log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up, found = %d", got)
}
- for router, _ := range ndp.defaultRouters {
+ for router := range ndp.defaultRouters {
ndp.invalidateDefaultRouter(router)
}
@@ -1132,3 +1168,84 @@ func (ndp *ndpState) cleanupHostOnlyState() {
log.Fatalf("ndp: still have discovered default routers after cleaning up, found = %d", got)
}
}
+
+// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
+// 6.3.7. If routers are already being solicited, this function does nothing.
+//
+// The NIC ndp belongs to MUST be locked.
+func (ndp *ndpState) startSolicitingRouters() {
+ if ndp.rtrSolicitTimer != nil {
+ // We are already soliciting routers.
+ return
+ }
+
+ remaining := ndp.configs.MaxRtrSolicitations
+ if remaining == 0 {
+ return
+ }
+
+ // Calculate the random delay before sending our first RS, as per RFC
+ // 4861 section 6.3.7.
+ var delay time.Duration
+ if ndp.configs.MaxRtrSolicitationDelay > 0 {
+ delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
+ }
+
+ ndp.rtrSolicitTimer = time.AfterFunc(delay, func() {
+ // Send an RS message with the unspecified source address.
+ ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing)
+ r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ defer r.Release()
+
+ payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize)
+ pkt := header.ICMPv6(hdr.Prepend(payloadSize))
+ pkt.SetType(header.ICMPv6RouterSolicit)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil,
+ NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ TOS: DefaultTOS,
+ }, tcpip.PacketBuffer{Header: hdr},
+ ); err != nil {
+ sent.Dropped.Increment()
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
+ // Don't send any more messages if we had an error.
+ remaining = 0
+ } else {
+ sent.RouterSolicit.Increment()
+ remaining--
+ }
+
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+ if remaining == 0 {
+ ndp.rtrSolicitTimer = nil
+ } else if ndp.rtrSolicitTimer != nil {
+ // Note, we need to explicitly check to make sure that
+ // the timer field is not nil because if it was nil but
+ // we still reached this point, then we know the NIC
+ // was requested to stop soliciting routers so we don't
+ // need to send the next Router Solicitation message.
+ ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval)
+ }
+ })
+
+}
+
+// stopSolicitingRouters stops soliciting routers. If routers are not currently
+// being solicited, this function does nothing.
+//
+// The NIC ndp belongs to MUST be locked.
+func (ndp *ndpState) stopSolicitingRouters() {
+ if ndp.rtrSolicitTimer == nil {
+ // Nothing to do.
+ return
+ }
+
+ ndp.rtrSolicitTimer.Stop()
+ ndp.rtrSolicitTimer = nil
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index d390c6312..ad2c6f601 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "context"
"encoding/binary"
"fmt"
"testing"
@@ -35,13 +36,14 @@ import (
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
- linkAddr1 = "\x02\x02\x03\x04\x05\x06"
- linkAddr2 = "\x02\x02\x03\x04\x05\x07"
- linkAddr3 = "\x02\x02\x03\x04\x05\x08"
- defaultTimeout = 100 * time.Millisecond
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
+ defaultTimeout = 100 * time.Millisecond
+ defaultAsyncEventTimeout = time.Second
)
var (
@@ -301,6 +303,8 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s
// Included in the subtests is a test to make sure that an invalid
// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
func TestDADResolve(t *testing.T) {
+ const nicID = 1
+
tests := []struct {
name string
dupAddrDetectTransmits uint8
@@ -331,44 +335,36 @@ func TestDADResolve(t *testing.T) {
opts.NDPConfigs.RetransmitTimer = test.retransTimer
opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1)
s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit
-
- // Should have sent an NDP NS immediately.
- if got := stat.Value(); got != 1 {
- t.Fatalf("got NeighborSolicit = %d, want = 1", got)
-
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
}
// Address should not be considered bound to the NIC yet
// (DAD ongoing).
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
// Wait for the remaining time - some delta (500ms), to
// make sure the address is still not resolved.
const delta = 500 * time.Millisecond
time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
// Wait for DAD to resolve.
@@ -385,8 +381,8 @@ func TestDADResolve(t *testing.T) {
if e.err != nil {
t.Fatal("got DAD error: ", e.err)
}
- if e.nicID != 1 {
- t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ if e.nicID != nicID {
+ t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID)
}
if e.addr != addr1 {
t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
@@ -395,22 +391,22 @@ func TestDADResolve(t *testing.T) {
t.Fatal("got DAD event w/ resolved = false, want = true")
}
}
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
}
// Should not have sent any more NS messages.
- if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
}
// Validate the sent Neighbor Solicitation messages.
for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
- p := <-e.C
+ p, _ := e.ReadContext(context.Background())
// Make sure its an IPv6 packet.
if p.Proto != header.IPv6ProtocolNumber {
@@ -425,7 +421,6 @@ func TestDADResolve(t *testing.T) {
}
})
}
-
}
// TestDADFail tests to make sure that the DAD process fails if another node is
@@ -1093,7 +1088,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
// Rx an RA from lladdr2 with huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
@@ -1110,7 +1105,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1349,7 +1344,7 @@ func TestPrefixDiscovery(t *testing.T) {
if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout):
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for prefix discovery event")
}
@@ -1688,7 +1683,7 @@ func TestAutoGenAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
@@ -1994,7 +1989,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -2034,7 +2029,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -2048,7 +2043,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultTimeout)
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout)
if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -2080,7 +2075,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
@@ -2095,7 +2090,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(newMinVLDuration + defaultTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -2220,7 +2215,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(minVLSeconds*time.Second + defaultTimeout):
+ case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -2445,6 +2440,119 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
}
+// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously
+// assigned to the NIC but is in the permanentExpired state.
+func TestAutoGenAddrAfterRemoval(t *testing.T) {
+ t.Parallel()
+
+ const nicID = 1
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive a PI to auto-generate addr1 with a large valid and preferred
+ // lifetime.
+ const largeLifetimeSeconds = 999
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectPrimaryAddr(addr1)
+
+ // Add addr2 as a static address.
+ protoAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr2,
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ }
+ // addr2 should be more preferred now since it is at the front of the primary
+ // list.
+ expectPrimaryAddr(addr2)
+
+ // Get a route using addr2 to increment its reference count then remove it
+ // to leave it in the permanentExpired state.
+ r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
+ }
+ // addr1 should be preferred again since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and preferred.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr2 should be more preferred now that it is closer to the front of the
+ // primary list and not deprecated.
+ expectPrimaryAddr(addr2)
+
+ // Removing the address should result in an invalidation event immediately.
+ // It should still be in the permanentExpired state because r is still held.
+ //
+ // We remove addr2 here to make sure addr2 was marked as a SLAAC address
+ // (it was previously marked as a static address).
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ // addr1 should be more preferred since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr1 should still be more preferred since addr2 is deprecated, even though
+ // it is closer to the front of the primary list.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to refresh addr2's preferred lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto gen addr event")
+ default:
+ }
+ // addr2 should be more preferred now that it is not deprecated.
+ expectPrimaryAddr(addr2)
+
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ expectPrimaryAddr(addr1)
+}
+
// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
// is already assigned to the NIC, the static address remains.
func TestAutoGenAddrStaticConflict(t *testing.T) {
@@ -2595,7 +2703,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultTimeout):
+ case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3098,3 +3206,223 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
expectNoDHCPv6Event()
}
+
+// TestRouterSolicitation tests the initial Router Solicitations that are sent
+// when a NIC newly becomes enabled.
+func TestRouterSolicitation(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ maxRtrSolicit uint8
+ rtrSolicitInt time.Duration
+ effectiveRtrSolicitInt time.Duration
+ maxRtrSolicitDelay time.Duration
+ effectiveMaxRtrSolicitDelay time.Duration
+ }{
+ {
+ name: "Single RS with delay",
+ maxRtrSolicit: 1,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: time.Second,
+ effectiveMaxRtrSolicitDelay: time.Second,
+ },
+ {
+ name: "Two RS with delay",
+ maxRtrSolicit: 2,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: 500 * time.Millisecond,
+ effectiveMaxRtrSolicitDelay: 500 * time.Millisecond,
+ },
+ {
+ name: "Single RS without delay",
+ maxRtrSolicit: 1,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Two RS without delay and invalid zero interval",
+ maxRtrSolicit: 2,
+ rtrSolicitInt: 0,
+ effectiveRtrSolicitInt: 4 * time.Second,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Three RS without delay",
+ maxRtrSolicit: 3,
+ rtrSolicitInt: 500 * time.Millisecond,
+ effectiveRtrSolicitInt: 500 * time.Millisecond,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Two RS with invalid negative delay",
+ maxRtrSolicit: 2,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: -3 * time.Second,
+ effectiveMaxRtrSolicitDelay: time.Second,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+ e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ p, ok := e.ReadContext(ctx)
+ if !ok {
+ t.Fatal("timed out waiting for packet")
+ return
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t,
+ p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(),
+ )
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet")
+ }
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Make sure each RS got sent at the right
+ // times.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
+ remaining--
+ }
+ for ; remaining > 0; remaining-- {
+ waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout)
+ waitForPkt(defaultAsyncEventTimeout)
+ }
+
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt + defaultTimeout)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultTimeout)
+ }
+
+ // Make sure the counter got properly
+ // incremented.
+ if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
+ }
+ })
+}
+
+// TestStopStartSolicitingRouters tests that when forwarding is enabled or
+// disabled, router solicitations are stopped or started, respecitively.
+func TestStopStartSolicitingRouters(t *testing.T) {
+ t.Parallel()
+
+ const interval = 500 * time.Millisecond
+ const delay = time.Second
+ const maxRtrSolicitations = 3
+ e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+ ctx, _ := context.WithTimeout(context.Background(), timeout)
+ p, ok := e.ReadContext(ctx)
+ if !ok {
+ t.Fatal("timed out waiting for packet")
+ return
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Enable forwarding which should stop router solicitations.
+ s.SetForwarding(true)
+ ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
+ // A single RS may have been sent before forwarding was enabled.
+ ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ if _, ok = e.ReadContext(ctx); ok {
+ t.Fatal("Should not have sent more than one RS message")
+ }
+ }
+
+ // Enabling forwarding again should do nothing.
+ s.SetForwarding(true)
+ ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after becoming a router")
+ }
+
+ // Disable forwarding which should start router solicitations.
+ s.SetForwarding(false)
+ waitForPkt(delay + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ }
+
+ // Disabling forwarding again should do nothing.
+ s.SetForwarding(false)
+ ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after becoming a router")
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index abf73fe33..7dad9a8cb 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -35,24 +35,21 @@ type NIC struct {
linkEP LinkEndpoint
context NICContext
- mu sync.RWMutex
- spoofing bool
- promiscuous bool
- primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
- endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
- addressRanges []tcpip.Subnet
- mcastJoins map[NetworkEndpointID]int32
- // packetEPs is protected by mu, but the contained PacketEndpoint
- // values are not.
- packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
-
stats NICStats
- // ndp is the NDP related state for NIC.
- //
- // Note, read and write operations on ndp require that the NIC is
- // appropriately locked.
- ndp ndpState
+ mu struct {
+ sync.RWMutex
+ spoofing bool
+ promiscuous bool
+ primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
+ endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
+ addressRanges []tcpip.Subnet
+ mcastJoins map[NetworkEndpointID]int32
+ // packetEPs is protected by mu, but the contained PacketEndpoint
+ // values are not.
+ packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
+ ndp ndpState
+ }
}
// NICStats includes transmitted and received stats.
@@ -97,15 +94,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// of IPv6 is supported on this endpoint's LinkEndpoint.
nic := &NIC{
- stack: stack,
- id: id,
- name: name,
- linkEP: ep,
- context: ctx,
- primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint),
- endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
- mcastJoins: make(map[NetworkEndpointID]int32),
- packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint),
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ context: ctx,
stats: NICStats{
Tx: DirectionStats{
Packets: &tcpip.StatCounter{},
@@ -116,22 +109,26 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
Bytes: &tcpip.StatCounter{},
},
},
- ndp: ndpState{
- configs: stack.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- autoGenAddresses: make(map[tcpip.Address]autoGenAddressState),
- },
}
- nic.ndp.nic = nic
+ nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
+ nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
+ nic.mu.mcastJoins = make(map[NetworkEndpointID]int32)
+ nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
+ nic.mu.ndp = ndpState{
+ nic: nic,
+ configs: stack.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ autoGenAddresses: make(map[tcpip.Address]autoGenAddressState),
+ }
// Register supported packet endpoint protocols.
for _, netProto := range header.Ethertypes {
- nic.packetEPs[netProto] = []PacketEndpoint{}
+ nic.mu.packetEPs[netProto] = []PacketEndpoint{}
}
for _, netProto := range stack.networkProtocols {
- nic.packetEPs[netProto.Number()] = []PacketEndpoint{}
+ nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{}
}
return nic
@@ -177,49 +174,72 @@ func (n *NIC) enable() *tcpip.Error {
}
// Do not auto-generate an IPv6 link-local address for loopback devices.
- if !n.stack.autoGenIPv6LinkLocal || n.isLoopback() {
- return nil
- }
+ if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
+ var addr tcpip.Address
+ if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey)
+ } else {
+ l2addr := n.linkEP.LinkAddress()
- var addr tcpip.Address
- if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
- addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey)
- } else {
- l2addr := n.linkEP.LinkAddress()
+ // Only attempt to generate the link-local address if we have a valid MAC
+ // address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ if !header.IsValidUnicastEthernetAddress(l2addr) {
+ return nil
+ }
- // Only attempt to generate the link-local address if we have a valid MAC
- // address.
- //
- // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
- // LinkEndpoint.LinkAddress) before reaching this point.
- if !header.IsValidUnicastEthernetAddress(l2addr) {
- return nil
+ addr = header.LinkLocalAddr(l2addr)
}
- addr = header.LinkLocalAddr(l2addr)
+ if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
+ },
+ }, CanBePrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
+ return err
+ }
}
- _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
- },
- }, CanBePrimaryEndpoint)
+ // If we are operating as a router, then do not solicit routers since we
+ // won't process the RAs anyways.
+ //
+ // Routers do not process Router Advertisements (RA) the same way a host
+ // does. That is, routers do not learn from RAs (e.g. on-link prefixes
+ // and default routers). Therefore, soliciting RAs from other routers on
+ // a link is unnecessary for routers.
+ if !n.stack.forwarding {
+ n.mu.ndp.startSolicitingRouters()
+ }
- return err
+ return nil
}
// becomeIPv6Router transitions n into an IPv6 router.
//
// When transitioning into an IPv6 router, host-only state (NDP discovered
// routers, discovered on-link prefixes, and auto-generated addresses) will
-// be cleaned up/invalidated.
+// be cleaned up/invalidated and NDP router solicitations will be stopped.
func (n *NIC) becomeIPv6Router() {
n.mu.Lock()
defer n.mu.Unlock()
- n.ndp.cleanupHostOnlyState()
+ n.mu.ndp.cleanupHostOnlyState()
+ n.mu.ndp.stopSolicitingRouters()
+}
+
+// becomeIPv6Host transitions n into an IPv6 host.
+//
+// When transitioning into an IPv6 host, NDP router solicitations will be
+// started.
+func (n *NIC) becomeIPv6Host() {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.mu.ndp.startSolicitingRouters()
}
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
@@ -231,13 +251,13 @@ func (n *NIC) attachLinkEndpoint() {
// setPromiscuousMode enables or disables promiscuous mode.
func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Lock()
- n.promiscuous = enable
+ n.mu.promiscuous = enable
n.mu.Unlock()
}
func (n *NIC) isPromiscuousMode() bool {
n.mu.RLock()
- rv := n.promiscuous
+ rv := n.mu.promiscuous
n.mu.RUnlock()
return rv
}
@@ -249,7 +269,7 @@ func (n *NIC) isLoopback() bool {
// setSpoofing enables or disables address spoofing.
func (n *NIC) setSpoofing(enable bool) {
n.mu.Lock()
- n.spoofing = enable
+ n.mu.spoofing = enable
n.mu.Unlock()
}
@@ -268,8 +288,8 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr t
defer n.mu.RUnlock()
var deprecatedEndpoint *referencedNetworkEndpoint
- for _, r := range n.primary[protocol] {
- if !r.isValidForOutgoing() {
+ for _, r := range n.mu.primary[protocol] {
+ if !r.isValidForOutgoingRLocked() {
continue
}
@@ -319,7 +339,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
n.mu.RLock()
defer n.mu.RUnlock()
- primaryAddrs := n.primary[header.IPv6ProtocolNumber]
+ primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber]
if len(primaryAddrs) == 0 {
return nil
@@ -402,7 +422,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn
// hasPermanentAddrLocked returns true if n has a permanent (including currently
// tentative) address, addr.
func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool {
- ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
if !ok {
return false
@@ -413,24 +433,54 @@ func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool {
return kind == permanent || kind == permanentTentative
}
+type getRefBehaviour int
+
+const (
+ // spoofing indicates that the NIC's spoofing flag should be observed when
+ // getting a NIC's referenced network endpoint.
+ spoofing getRefBehaviour = iota
+
+ // promiscuous indicates that the NIC's promiscuous flag should be observed
+ // when getting a NIC's referenced network endpoint.
+ promiscuous
+
+ // forceSpoofing indicates that the NIC should be assumed to be spoofing,
+ // regardless of what the NIC's spoofing flag is when getting a NIC's
+ // referenced network endpoint.
+ forceSpoofing
+)
+
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
- return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous)
+ return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
}
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
- return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing)
+ return n.getRefOrCreateTemp(protocol, address, peb, spoofing)
}
// getRefEpOrCreateTemp returns the referenced network endpoint for the given
-// protocol and address. If none exists a temporary one may be created if
-// we are in promiscuous mode or spoofing.
-func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint {
+// protocol and address.
+//
+// If none exists a temporary one may be created if we are in promiscuous mode
+// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
+// Similarly, spoofing will only be checked if spoofing is true.
+func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
- if ref, ok := n.endpoints[id]; ok {
+ var spoofingOrPromiscuous bool
+ switch tempRef {
+ case spoofing:
+ spoofingOrPromiscuous = n.mu.spoofing
+ case promiscuous:
+ spoofingOrPromiscuous = n.mu.promiscuous
+ case forceSpoofing:
+ spoofingOrPromiscuous = true
+ }
+
+ if ref, ok := n.mu.endpoints[id]; ok {
// An endpoint with this id exists, check if it can be used and return it.
switch ref.getKind() {
case permanentExpired:
@@ -451,7 +501,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// the caller or if the address is found in the NIC's subnets.
createTempEP := spoofingOrPromiscuous
if !createTempEP {
- for _, sn := range n.addressRanges {
+ for _, sn := range n.mu.addressRanges {
// Skip the subnet address.
if address == sn.ID() {
continue
@@ -479,7 +529,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- if ref, ok := n.endpoints[id]; ok {
+ if ref, ok := n.mu.endpoints[id]; ok {
// No need to check the type as we are ok with expired endpoints at this
// point.
if ref.tryIncRef() {
@@ -510,20 +560,36 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
return ref
}
-func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) {
- id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
- if ref, ok := n.endpoints[id]; ok {
+// addAddressLocked adds a new protocolAddress to n.
+//
+// If n already has the address in a non-permanent state, and the kind given is
+// permanent, that address will be promoted in place and its properties set to
+// the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress.
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+ // TODO(b/141022673): Validate IP addresses before adding them.
+
+ // Sanity check.
+ id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address}
+ if ref, ok := n.mu.endpoints[id]; ok {
+ // Endpoint already exists.
+ if kind != permanent {
+ return nil, tcpip.ErrDuplicateAddress
+ }
switch ref.getKind() {
case permanentTentative, permanent:
// The NIC already have a permanent endpoint with that address.
return nil, tcpip.ErrDuplicateAddress
case permanentExpired, temporary:
- // Promote the endpoint to become permanent and respect
- // the new peb.
+ // Promote the endpoint to become permanent and respect the new peb,
+ // configType and deprecated status.
if ref.tryIncRef() {
+ // TODO(b/147748385): Perform Duplicate Address Detection when promoting
+ // an IPv6 endpoint to permanent.
ref.setKind(permanent)
+ ref.deprecated = deprecated
+ ref.configType = configType
- refs := n.primary[ref.protocol]
+ refs := n.mu.primary[ref.protocol]
for i, r := range refs {
if r == ref {
switch peb {
@@ -533,9 +599,9 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
if i == 0 {
return ref, nil
}
- n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
case NeverPrimaryEndpoint:
- n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
return ref, nil
}
}
@@ -553,26 +619,13 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
}
}
- return n.addAddressLocked(protocolAddress, peb, permanent, static, false)
-}
-
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
- // TODO(b/141022673): Validate IP address before adding them.
-
- // Sanity check.
- id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
- if _, ok := n.endpoints[id]; ok {
- // Endpoint already exists.
- return nil, tcpip.ErrDuplicateAddress
- }
-
netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol]
if !ok {
return nil, tcpip.ErrUnknownProtocol
}
// Create the new network endpoint.
- ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP)
+ ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP, n.stack)
if err != nil {
return nil, err
}
@@ -611,13 +664,13 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
}
}
- n.endpoints[id] = ref
+ n.mu.endpoints[id] = ref
n.insertPrimaryEndpointLocked(ref, peb)
// If we are adding a tentative IPv6 address, start DAD.
if isIPv6Unicast && kind == permanentTentative {
- if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
+ if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
return nil, err
}
}
@@ -630,7 +683,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
- _, err := n.addPermanentAddressLocked(protocolAddress, peb)
+ _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */)
n.mu.Unlock()
return err
@@ -642,8 +695,8 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
n.mu.RLock()
defer n.mu.RUnlock()
- addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
- for nid, ref := range n.endpoints {
+ addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints))
+ for nid, ref := range n.mu.endpoints {
// Don't include tentative, expired or temporary endpoints to
// avoid confusion and prevent the caller from using those.
switch ref.getKind() {
@@ -669,7 +722,7 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
defer n.mu.RUnlock()
var addrs []tcpip.ProtocolAddress
- for proto, list := range n.primary {
+ for proto, list := range n.mu.primary {
for _, ref := range list {
// Don't include tentative, expired or tempory endpoints
// to avoid confusion and prevent the caller from using
@@ -700,7 +753,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
n.mu.RLock()
defer n.mu.RUnlock()
- list, ok := n.primary[proto]
+ list, ok := n.mu.primary[proto]
if !ok {
return tcpip.AddressWithPrefix{}
}
@@ -743,7 +796,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
// address.
func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
n.mu.Lock()
- n.addressRanges = append(n.addressRanges, subnet)
+ n.mu.addressRanges = append(n.mu.addressRanges, subnet)
n.mu.Unlock()
}
@@ -752,23 +805,23 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) {
n.mu.Lock()
// Use the same underlying array.
- tmp := n.addressRanges[:0]
- for _, sub := range n.addressRanges {
+ tmp := n.mu.addressRanges[:0]
+ for _, sub := range n.mu.addressRanges {
if sub != subnet {
tmp = append(tmp, sub)
}
}
- n.addressRanges = tmp
+ n.mu.addressRanges = tmp
n.mu.Unlock()
}
-// Subnets returns the Subnets associated with this NIC.
+// AddressRanges returns the Subnets associated with this NIC.
func (n *NIC) AddressRanges() []tcpip.Subnet {
n.mu.RLock()
defer n.mu.RUnlock()
- sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints))
- for nid := range n.endpoints {
+ sns := make([]tcpip.Subnet, 0, len(n.mu.addressRanges)+len(n.mu.endpoints))
+ for nid := range n.mu.endpoints {
sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
if err != nil {
// This should never happen as the mask has been carefully crafted to
@@ -777,7 +830,7 @@ func (n *NIC) AddressRanges() []tcpip.Subnet {
}
sns = append(sns, sn)
}
- return append(sns, n.addressRanges...)
+ return append(sns, n.mu.addressRanges...)
}
// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required
@@ -787,9 +840,9 @@ func (n *NIC) AddressRanges() []tcpip.Subnet {
func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) {
switch peb {
case CanBePrimaryEndpoint:
- n.primary[r.protocol] = append(n.primary[r.protocol], r)
+ n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r)
case FirstPrimaryEndpoint:
- n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...)
+ n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...)
}
}
@@ -801,7 +854,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
// and was waiting (on the lock) to be removed and 2) the same address was
// re-added in the meantime by removing this endpoint from the list and
// adding a new one.
- if n.endpoints[id] != r {
+ if n.mu.endpoints[id] != r {
return
}
@@ -809,11 +862,11 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
panic("Reference count dropped to zero before being removed")
}
- delete(n.endpoints, id)
- refs := n.primary[r.protocol]
+ delete(n.mu.endpoints, id)
+ refs := n.mu.primary[r.protocol]
for i, ref := range refs {
if ref == r {
- n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
break
}
}
@@ -828,7 +881,7 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
}
func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
- r, ok := n.endpoints[NetworkEndpointID{addr}]
+ r, ok := n.mu.endpoints[NetworkEndpointID{addr}]
if !ok {
return tcpip.ErrBadLocalAddress
}
@@ -844,13 +897,13 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
// If we are removing a tentative IPv6 unicast address, stop
// DAD.
if kind == permanentTentative {
- n.ndp.stopDuplicateAddressDetection(addr)
+ n.mu.ndp.stopDuplicateAddressDetection(addr)
}
// If we are removing an address generated via SLAAC, cleanup
// its SLAAC resources and notify the integrator.
if r.configType == slaac {
- n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr)
+ n.mu.ndp.cleanupAutoGenAddrResourcesAndNotify(addr)
}
}
@@ -900,23 +953,23 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A
// outlined in RFC 3810 section 5.
id := NetworkEndpointID{addr}
- joins := n.mcastJoins[id]
+ joins := n.mu.mcastJoins[id]
if joins == 0 {
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
return tcpip.ErrUnknownProtocol
}
- if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
+ if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: addr,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, NeverPrimaryEndpoint); err != nil {
+ }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
return err
}
}
- n.mcastJoins[id] = joins + 1
+ n.mu.mcastJoins[id] = joins + 1
return nil
}
@@ -934,7 +987,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
// before leaveGroupLocked is called.
func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
id := NetworkEndpointID{addr}
- joins := n.mcastJoins[id]
+ joins := n.mu.mcastJoins[id]
switch joins {
case 0:
// There are no joins with this address on this NIC.
@@ -945,7 +998,7 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
return err
}
}
- n.mcastJoins[id] = joins - 1
+ n.mu.mcastJoins[id] = joins - 1
return nil
}
@@ -958,7 +1011,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
-// the NIC receives a packet from the physical interface.
+// the NIC receives a packet from the link endpoint.
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
@@ -980,12 +1033,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// Are any packet sockets listening for this network protocol?
n.mu.RLock()
- packetEPs := n.packetEPs[protocol]
+ packetEPs := n.mu.packetEPs[protocol]
// Check whether there are packet sockets listening for every protocol.
// If we received a packet with protocol EthernetProtocolAll, then the
// previous for loop will have handled it.
if protocol != header.EthernetProtocolAll {
- packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...)
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
}
n.mu.RUnlock()
for _, ep := range packetEPs {
@@ -1003,6 +1056,14 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
src, dst := netProto.ParseAddresses(pkt.Data.First())
+ if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
+ // The source address is one of our own, so we never should have gotten a
+ // packet like this unless handleLocal is false. Loopback also calls this
+ // function even though the packets didn't come from the physical interface
+ // so don't drop those.
+ n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
if ref := n.getRef(protocol, dst); ref != nil {
handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt)
return
@@ -1015,7 +1076,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
if n.stack.Forwarding() {
r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
if err != nil {
- n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
return
}
defer r.Release()
@@ -1026,8 +1087,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// Found a NIC.
n := r.ref.nic
n.mu.RLock()
- ref, ok := n.endpoints[NetworkEndpointID{dst}]
- ok = ok && ref.isValidForOutgoing() && ref.tryIncRef()
+ ref, ok := n.mu.endpoints[NetworkEndpointID{dst}]
+ ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef()
n.mu.RUnlock()
if ok {
r.RemoteAddress = src
@@ -1053,7 +1114,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// If a packet socket handled the packet, don't treat it as invalid.
if len(packetEPs) == 0 {
- n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
}
}
@@ -1147,7 +1208,10 @@ func (n *NIC) Stack() *Stack {
// false. It will only return true if the address is associated with the NIC
// AND it is tentative.
func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
- ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
if !ok {
return false
}
@@ -1163,7 +1227,7 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
if !ok {
return tcpip.ErrBadAddress
}
@@ -1183,7 +1247,7 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) {
c.validate()
n.mu.Lock()
- n.ndp.configs = c
+ n.mu.ndp.configs = c
n.mu.Unlock()
}
@@ -1192,7 +1256,7 @@ func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
n.mu.Lock()
defer n.mu.Unlock()
- n.ndp.handleRA(ip, ra)
+ n.mu.ndp.handleRA(ip, ra)
}
type networkEndpointKind int32
@@ -1234,11 +1298,11 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
n.mu.Lock()
defer n.mu.Unlock()
- eps, ok := n.packetEPs[netProto]
+ eps, ok := n.mu.packetEPs[netProto]
if !ok {
return tcpip.ErrNotSupported
}
- n.packetEPs[netProto] = append(eps, ep)
+ n.mu.packetEPs[netProto] = append(eps, ep)
return nil
}
@@ -1247,14 +1311,14 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
n.mu.Lock()
defer n.mu.Unlock()
- eps, ok := n.packetEPs[netProto]
+ eps, ok := n.mu.packetEPs[netProto]
if !ok {
return
}
for i, epOther := range eps {
if epOther == ep {
- n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
+ n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
return
}
}
@@ -1290,7 +1354,8 @@ type referencedNetworkEndpoint struct {
kind networkEndpointKind
// configType is the method that was used to configure this endpoint.
- // This must never change after the endpoint is added to a NIC.
+ // This must never change except during endpoint creation and promotion to
+ // permanent.
configType networkEndpointConfigType
// deprecated indicates whether or not the endpoint should be considered
@@ -1311,14 +1376,19 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
// packet. It requires the endpoint to not be marked expired (i.e., its address
// has been removed), or the NIC to be in spoofing mode.
func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
- return r.getKind() != permanentExpired || r.nic.spoofing
+ r.nic.mu.RLock()
+ defer r.nic.mu.RUnlock()
+
+ return r.isValidForOutgoingRLocked()
}
-// isValidForIncoming returns true if the endpoint can accept an incoming
-// packet. It requires the endpoint to not be marked expired (i.e., its address
-// has been removed), or the NIC to be in promiscuous mode.
-func (r *referencedNetworkEndpoint) isValidForIncoming() bool {
- return r.getKind() != permanentExpired || r.nic.promiscuous
+// isValidForOutgoingRLocked returns true if the endpoint can be used to send
+// out a packet. It requires the endpoint to not be marked expired (i.e., its
+// address has been removed), or the NIC to be in spoofing mode.
+//
+// r's NIC must be read locked.
+func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
+ return r.getKind() != permanentExpired || r.nic.mu.spoofing
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 2b8751d49..ec91f60dd 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -282,7 +282,7 @@ type NetworkProtocol interface {
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
+ NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) (NetworkEndpoint, *tcpip.Error)
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index f8d89248e..7057b110e 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -423,7 +423,11 @@ type Stack struct {
// handleLocal allows non-loopback interfaces to loop packets.
handleLocal bool
- // tables are the iptables packet filtering and manipulation rules.
+ // tablesMu protects iptables.
+ tablesMu sync.RWMutex
+
+ // tables are the iptables packet filtering and manipulation rules. The are
+ // protected by tablesMu.`
tables iptables.IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
@@ -750,7 +754,9 @@ func (s *Stack) Stats() tcpip.Stats {
// SetForwarding enables or disables the packet forwarding between NICs.
//
// When forwarding becomes enabled, any host-only state on all NICs will be
-// cleaned up.
+// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started.
+// When forwarding becomes disabled and if IPv6 is enabled, NDP Router
+// Solicitations will be stopped.
func (s *Stack) SetForwarding(enable bool) {
// TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
s.mu.Lock()
@@ -772,6 +778,10 @@ func (s *Stack) SetForwarding(enable bool) {
for _, nic := range s.nics {
nic.becomeIPv6Router()
}
+ } else {
+ for _, nic := range s.nics {
+ nic.becomeIPv6Host()
+ }
}
}
@@ -912,7 +922,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool {
return false
}
-// NICSubnets returns a map of NICIDs to their associated subnets.
+// NICAddressRanges returns a map of NICIDs to their associated subnets.
func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1588,12 +1598,17 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC
// IPTables returns the stack's iptables.
func (s *Stack) IPTables() iptables.IPTables {
- return s.tables
+ s.tablesMu.RLock()
+ t := s.tables
+ s.tablesMu.RUnlock()
+ return t
}
// SetIPTables sets the stack's iptables.
func (s *Stack) SetIPTables(ipt iptables.IPTables) {
+ s.tablesMu.Lock()
s.tables = ipt
+ s.tablesMu.Unlock()
}
// ICMPLimit returns the maximum number of ICMP messages that can be sent
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 4b3d18f1b..834fe9487 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -201,7 +201,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
nicID: nicID,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
@@ -1880,9 +1880,7 @@ func TestNICForwarding(t *testing.T) {
Data: buf.ToVectorisedView(),
})
- select {
- case <-ep2.C:
- default:
+ if _, ok := ep2.Read(); !ok {
t.Fatal("Packet not forwarded")
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index f384a91de..d686e6eb8 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -104,7 +104,14 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p
return
}
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt)
+ transEP := selectEndpoint(id, mpep, epsByNic.seed)
+ if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
+ queuedProtocol.QueuePacket(r, transEP, id, pkt)
+ epsByNic.mu.RUnlock()
+ return
+ }
+
+ transEP.HandlePacket(r, id, pkt)
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -130,7 +137,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint
// registerEndpoint returns true if it succeeds. It fails and returns
// false if ep already has an element with the same key.
-func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
epsByNic.mu.Lock()
defer epsByNic.mu.Unlock()
@@ -140,7 +147,7 @@ func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort
}
// This is a new binding.
- multiPortEp := &multiPortEndpoint{}
+ multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto}
multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
multiPortEp.reuse = reusePort
epsByNic.endpoints[bindToDevice] = multiPortEp
@@ -168,18 +175,34 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T
// newTransportDemuxer.
type transportDemuxer struct {
// protocol is immutable.
- protocol map[protocolIDs]*transportEndpoints
+ protocol map[protocolIDs]*transportEndpoints
+ queuedProtocols map[protocolIDs]queuedTransportProtocol
+}
+
+// queuedTransportProtocol if supported by a protocol implementation will cause
+// the dispatcher to delivery packets to the QueuePacket method instead of
+// calling HandlePacket directly on the endpoint.
+type queuedTransportProtocol interface {
+ QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
- d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
+ d := &transportDemuxer{
+ protocol: make(map[protocolIDs]*transportEndpoints),
+ queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
+ }
// Add each network and transport pair to the demuxer.
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
- d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
+ protoIDs := protocolIDs{netProto, proto}
+ d.protocol[protoIDs] = &transportEndpoints{
endpoints: make(map[TransportEndpointID]*endpointsByNic),
}
+ qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol)
+ if isQueued {
+ d.queuedProtocols[protoIDs] = qTransProto
+ }
}
}
@@ -209,7 +232,11 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
//
// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ demux *transportDemuxer
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
// reuse indicates if more than one endpoint is allowed.
@@ -258,13 +285,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
ep.mu.RLock()
+ queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
for i, endpoint := range ep.endpointsArr {
// HandlePacket takes ownership of pkt, so each endpoint needs
// its own copy except for the final one.
if i == len(ep.endpointsArr)-1 {
+ if mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ break
+ }
endpoint.HandlePacket(r, id, pkt)
break
}
+ if mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
+ continue
+ }
endpoint.HandlePacket(r, id, pkt.Clone())
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
@@ -357,7 +393,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
if epsByNic, ok := eps.endpoints[id]; ok {
// There was already a binding.
- return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
+ return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
}
// This is a new binding.
@@ -367,7 +403,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
}
eps.endpoints[id] = epsByNic
- return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
+ return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index f50604a8a..869c69a6d 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -623,10 +623,8 @@ func TestTransportForwarding(t *testing.T) {
t.Fatalf("Write failed: %v", err)
}
- var p channel.PacketInfo
- select {
- case p = <-ep2.C:
- default:
+ p, ok := ep2.Read()
+ if !ok {
t.Fatal("Response packet not forwarded")
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 4a090ac86..59c9b3fb0 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -40,7 +40,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -322,7 +321,7 @@ type ControlMessages struct {
HasTOS bool
// TOS is the IPv4 type of service of the associated packet.
- TOS int8
+ TOS uint8
// HasTClass indicates whether Tclass is valid/set.
HasTClass bool
@@ -454,9 +453,6 @@ type Endpoint interface {
// NOTE: This method is a no-op for sockets other than TCP.
ModerateRecvBuf(copied int)
- // IPTables returns the iptables for this endpoint's stack.
- IPTables() (iptables.IPTables, error)
-
// Info returns a copy to the transport endpoint info.
Info() EndpointInfo
@@ -500,9 +496,13 @@ type WriteOptions struct {
type SockOptBool int
const (
+ // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS
+ // ancillary message is passed with incoming packets.
+ ReceiveTOSOption SockOptBool = iota
+
// V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
- V6OnlyOption SockOptBool = iota
+ V6OnlyOption
)
// SockOptInt represents socket options which values have the int type.
@@ -899,9 +899,13 @@ type IPStats struct {
// link layer in nic.DeliverNetworkPacket.
PacketsReceived *StatCounter
- // InvalidAddressesReceived is the total number of IP packets received
- // with an unknown or invalid destination address.
- InvalidAddressesReceived *StatCounter
+ // InvalidDestinationAddressesReceived is the total number of IP packets
+ // received with an unknown or invalid destination address.
+ InvalidDestinationAddressesReceived *StatCounter
+
+ // InvalidSourceAddressesReceived is the total number of IP packets received
+ // with a source address that should never have been received on the wire.
+ InvalidSourceAddressesReceived *StatCounter
// PacketsDelivered is the total number of incoming IP packets that
// are successfully delivered to the transport layer via HandlePacket.
@@ -934,9 +938,13 @@ type TCPStats struct {
PassiveConnectionOpenings *StatCounter
// CurrentEstablished is the number of TCP connections for which the
- // current state is either ESTABLISHED or CLOSE-WAIT.
+ // current state is ESTABLISHED.
CurrentEstablished *StatCounter
+ // CurrentConnected is the number of TCP connections that
+ // are in connected state.
+ CurrentConnected *StatCounter
+
// EstablishedResets is the number of times TCP connections have made
// a direct transition to the CLOSED state from either the
// ESTABLISHED state or the CLOSE-WAIT state.
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 3aa23d529..ac18ec5b1 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -23,7 +23,6 @@ go_library(
"icmp_packet_list.go",
"protocol.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/icmp",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
index 4858d150c..d22de6b26 100644
--- a/pkg/tcpip/transport/packet/BUILD
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -22,7 +22,6 @@ go_library(
"endpoint_state.go",
"packet_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 2f2131ff7..c9baf4600 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -23,7 +23,6 @@ go_library(
"protocol.go",
"raw_packet_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 353bd06f4..4acd9fb9a 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -16,6 +15,18 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "tcp_endpoint_list",
+ out = "tcp_endpoint_list.go",
+ package = "tcp",
+ prefix = "endpoint",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*endpoint",
+ "Linker": "*endpoint",
+ },
+)
+
go_library(
name = "tcp",
srcs = [
@@ -23,6 +34,7 @@ go_library(
"connect.go",
"cubic.go",
"cubic_state.go",
+ "dispatcher.go",
"endpoint.go",
"endpoint_state.go",
"forwarder.go",
@@ -38,14 +50,13 @@ go_library(
"segment_state.go",
"snd.go",
"snd_state.go",
+ "tcp_endpoint_list.go",
"tcp_segment_list.go",
"timer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
- "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/sync",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 1ea996936..d469758eb 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -285,7 +285,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// listenEP is nil when listenContext is used by tcp.Forwarder.
if l.listenEP != nil {
l.listenEP.mu.Lock()
- if l.listenEP.state != StateListen {
+ if l.listenEP.EndpointState() != StateListen {
l.listenEP.mu.Unlock()
return nil, tcpip.ErrConnectionAborted
}
@@ -344,11 +344,12 @@ func (l *listenContext) closeAllPendingEndpoints() {
// instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.Lock()
- state := e.state
+ state := e.EndpointState()
e.pendingAccepted.Add(1)
defer e.pendingAccepted.Done()
acceptedChan := e.acceptedChan
e.mu.Unlock()
+
if state == StateListen {
acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
@@ -561,9 +562,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// Switch state to connected.
// We do not use transitionToStateEstablishedLocked here as there is
// no handshake state available when doing a SYN cookie based accept.
- n.stack.Stats().TCP.CurrentEstablished.Increment()
- n.state = StateEstablished
n.isConnectNotified = true
+ n.setEndpointState(StateEstablished)
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
@@ -596,7 +596,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
// handleSynSegment() from attempting to queue new connections
// to the endpoint.
e.mu.Lock()
- e.state = StateClose
+ e.setEndpointState(StateClose)
// close any endpoints in SYN-RCVD state.
ctx.closeAllPendingEndpoints()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 613ec1775..4e3c5419c 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -190,7 +190,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.mss = opts.MSS
h.sndWndScale = opts.WS
h.ep.mu.Lock()
- h.ep.state = StateSynRecv
+ h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
}
@@ -274,14 +274,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// SYN-RCVD state.
h.state = handshakeSynRcvd
h.ep.mu.Lock()
- h.ep.state = StateSynRecv
ttl := h.ep.ttl
+ h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
// We only send SACKPermitted if the other side indicated it
// permits SACK. This is not explicitly defined in the RFC but
@@ -341,7 +341,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
WS: h.rcvWndScale,
TS: h.ep.sendTSOk,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
@@ -501,7 +501,7 @@ func (h *handshake) execute() *tcpip.Error {
WS: h.rcvWndScale,
TS: true,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
}
@@ -792,7 +792,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:])
+ offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
}
if e.sackPermitted && len(sackBlocks) > 0 {
offset += header.EncodeNOP(options[offset:])
@@ -811,7 +811,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// sendRaw sends a TCP segment to the endpoint's peer.
func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
var sackBlocks []header.SACKBlock
- if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
@@ -848,6 +848,9 @@ func (e *endpoint) handleWrite() *tcpip.Error {
}
func (e *endpoint) handleClose() *tcpip.Error {
+ if !e.EndpointState().connected() {
+ return nil
+ }
// Drain the send queue.
e.handleWrite()
@@ -864,11 +867,7 @@ func (e *endpoint) handleClose() *tcpip.Error {
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
- if e.state == StateEstablished || e.state == StateCloseWait {
- e.stack.Stats().TCP.EstablishedResets.Increment()
- e.stack.Stats().TCP.CurrentEstablished.Decrement()
- }
- e.state = StateError
+ e.setEndpointState(StateError)
e.HardError = err
if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
// The exact sequence number to be used for the RST is the same as the
@@ -888,9 +887,12 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
}
// completeWorkerLocked is called by the worker goroutine when it's about to
-// exit. It marks the worker as completed and performs cleanup work if requested
-// by Close().
+// exit.
func (e *endpoint) completeWorkerLocked() {
+ // Worker is terminating(either due to moving to
+ // CLOSED or ERROR state, ensure we release all
+ // registrations port reservations even if the socket
+ // itself is not yet closed by the application.
e.workerRunning = false
if e.workerCleanup {
e.cleanupLocked()
@@ -917,8 +919,7 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
e.rcvAutoParams.prevCopied = int(h.rcvWnd)
e.rcvListMu.Unlock()
}
- h.ep.stack.Stats().TCP.CurrentEstablished.Increment()
- e.state = StateEstablished
+ e.setEndpointState(StateEstablished)
}
// transitionToStateCloseLocked ensures that the endpoint is
@@ -927,11 +928,13 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// delivered to this endpoint from the demuxer when the endpoint
// is transitioned to StateClose.
func (e *endpoint) transitionToStateCloseLocked() {
- if e.state == StateClose {
+ if e.EndpointState() == StateClose {
return
}
+ // Mark the endpoint as fully closed for reads/writes.
e.cleanupLocked()
- e.state = StateClose
+ e.setEndpointState(StateClose)
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
e.stack.Stats().TCP.EstablishedClosed.Increment()
}
@@ -946,7 +949,9 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
s.decRef()
return
}
- ep.(*endpoint).enqueueSegment(s)
+ if ep.(*endpoint).enqueueSegment(s) {
+ ep.(*endpoint).newSegmentWaker.Assert()
+ }
}
func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
@@ -955,9 +960,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// except SYN-SENT, all reset (RST) segments are
// validated by checking their SEQ-fields." So
// we only process it if it's acceptable.
- s.decRef()
e.mu.Lock()
- switch e.state {
+ switch e.EndpointState() {
// In case of a RST in CLOSE-WAIT linux moves
// the socket to closed state with an error set
// to indicate EPIPE.
@@ -981,103 +985,53 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
e.transitionToStateCloseLocked()
e.HardError = tcpip.ErrAborted
e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
e.mu.Unlock()
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+
+ // Notify protocol goroutine. This is required when
+ // handleSegment is invoked from the processor goroutine
+ // rather than the worker goroutine.
+ e.notifyProtocolGoroutine(notifyResetByPeer)
return false, tcpip.ErrConnectionReset
}
}
return true, nil
}
-// handleSegments pulls segments from the queue and processes them. It returns
-// no error if the protocol loop should continue, an error otherwise.
-func (e *endpoint) handleSegments() *tcpip.Error {
+// handleSegments processes all inbound segments.
+func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
+ if e.EndpointState() == StateClose || e.EndpointState() == StateError {
+ return nil
+ }
s := e.segmentQueue.dequeue()
if s == nil {
checkRequeue = false
break
}
- // Invoke the tcp probe if installed.
- if e.probe != nil {
- e.probe(e.completeState())
+ cont, err := e.handleSegment(s)
+ if err != nil {
+ s.decRef()
+ return err
}
-
- if s.flagIsSet(header.TCPFlagRst) {
- if ok, err := e.handleReset(s); !ok {
- return err
- }
- } else if s.flagIsSet(header.TCPFlagSyn) {
- // See: https://tools.ietf.org/html/rfc5961#section-4.1
- // 1) If the SYN bit is set, irrespective of the sequence number, TCP
- // MUST send an ACK (also referred to as challenge ACK) to the remote
- // peer:
- //
- // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
- //
- // After sending the acknowledgment, TCP MUST drop the unacceptable
- // segment and stop processing further.
- //
- // By sending an ACK, the remote peer is challenged to confirm the loss
- // of the previous connection and the request to start a new connection.
- // A legitimate peer, after restart, would not have a TCB in the
- // synchronized state. Thus, when the ACK arrives, the peer should send
- // a RST segment back with the sequence number derived from the ACK
- // field that caused the RST.
-
- // This RST will confirm that the remote peer has indeed closed the
- // previous connection. Upon receipt of a valid RST, the local TCP
- // endpoint MUST terminate its connection. The local TCP endpoint
- // should then rely on SYN retransmission from the remote end to
- // re-establish the connection.
-
- e.snd.sendAck()
- } else if s.flagIsSet(header.TCPFlagAck) {
- // Patch the window size in the segment according to the
- // send window scale.
- s.window <<= e.snd.sndWndScale
-
- // RFC 793, page 41 states that "once in the ESTABLISHED
- // state all segments must carry current acknowledgment
- // information."
- drop, err := e.rcv.handleRcvdSegment(s)
- if err != nil {
- s.decRef()
- return err
- }
- if drop {
- s.decRef()
- continue
- }
-
- // Now check if the received segment has caused us to transition
- // to a CLOSED state, if yes then terminate processing and do
- // not invoke the sender.
- e.mu.RLock()
- state := e.state
- e.mu.RUnlock()
- if state == StateClose {
- // When we get into StateClose while processing from the queue,
- // return immediately and let the protocolMainloop handle it.
- //
- // We can reach StateClose only while processing a previous segment
- // or a notification from the protocolMainLoop (caller goroutine).
- // This means that with this return, the segment dequeue below can
- // never occur on a closed endpoint.
- s.decRef()
- return nil
- }
- e.snd.handleRcvdSegment(s)
+ if !cont {
+ s.decRef()
+ return nil
}
- s.decRef()
}
- // If the queue is not empty, make sure we'll wake up in the next
- // iteration.
- if checkRequeue && !e.segmentQueue.empty() {
+ // When fastPath is true we don't want to wake up the worker
+ // goroutine. If the endpoint has more segments to process the
+ // dispatcher will call handleSegments again anyway.
+ if !fastPath && checkRequeue && !e.segmentQueue.empty() {
e.newSegmentWaker.Assert()
}
@@ -1086,11 +1040,88 @@ func (e *endpoint) handleSegments() *tcpip.Error {
e.snd.sendAck()
}
- e.resetKeepaliveTimer(true)
+ e.resetKeepaliveTimer(true /* receivedData */)
return nil
}
+// handleSegment handles a given segment and notifies the worker goroutine if
+// if the connection should be terminated.
+func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+ // Invoke the tcp probe if installed.
+ if e.probe != nil {
+ e.probe(e.completeState())
+ }
+
+ if s.flagIsSet(header.TCPFlagRst) {
+ if ok, err := e.handleReset(s); !ok {
+ return false, err
+ }
+ } else if s.flagIsSet(header.TCPFlagSyn) {
+ // See: https://tools.ietf.org/html/rfc5961#section-4.1
+ // 1) If the SYN bit is set, irrespective of the sequence number, TCP
+ // MUST send an ACK (also referred to as challenge ACK) to the remote
+ // peer:
+ //
+ // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
+ //
+ // After sending the acknowledgment, TCP MUST drop the unacceptable
+ // segment and stop processing further.
+ //
+ // By sending an ACK, the remote peer is challenged to confirm the loss
+ // of the previous connection and the request to start a new connection.
+ // A legitimate peer, after restart, would not have a TCB in the
+ // synchronized state. Thus, when the ACK arrives, the peer should send
+ // a RST segment back with the sequence number derived from the ACK
+ // field that caused the RST.
+
+ // This RST will confirm that the remote peer has indeed closed the
+ // previous connection. Upon receipt of a valid RST, the local TCP
+ // endpoint MUST terminate its connection. The local TCP endpoint
+ // should then rely on SYN retransmission from the remote end to
+ // re-establish the connection.
+
+ e.snd.sendAck()
+ } else if s.flagIsSet(header.TCPFlagAck) {
+ // Patch the window size in the segment according to the
+ // send window scale.
+ s.window <<= e.snd.sndWndScale
+
+ // RFC 793, page 41 states that "once in the ESTABLISHED
+ // state all segments must carry current acknowledgment
+ // information."
+ drop, err := e.rcv.handleRcvdSegment(s)
+ if err != nil {
+ return false, err
+ }
+ if drop {
+ return true, nil
+ }
+
+ // Now check if the received segment has caused us to transition
+ // to a CLOSED state, if yes then terminate processing and do
+ // not invoke the sender.
+ e.mu.RLock()
+ state := e.state
+ e.mu.RUnlock()
+ if state == StateClose {
+ // When we get into StateClose while processing from the queue,
+ // return immediately and let the protocolMainloop handle it.
+ //
+ // We can reach StateClose only while processing a previous segment
+ // or a notification from the protocolMainLoop (caller goroutine).
+ // This means that with this return, the segment dequeue below can
+ // never occur on a closed endpoint.
+ s.decRef()
+ return false, nil
+ }
+
+ e.snd.handleRcvdSegment(s)
+ }
+
+ return true, nil
+}
+
// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
@@ -1160,7 +1191,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
-func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error {
var closeTimer *time.Timer
var closeWaker sleep.Waker
@@ -1182,6 +1213,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.mu.Unlock()
+ e.workMu.Unlock()
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -1193,7 +1225,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
initialRcvWnd := e.initialReceiveWindow()
h := newHandshake(e, seqnum.Size(initialRcvWnd))
e.mu.Lock()
- h.ep.state = StateSynSent
+ h.ep.setEndpointState(StateSynSent)
e.mu.Unlock()
if err := h.execute(); err != nil {
@@ -1202,12 +1234,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.lastErrorMu.Unlock()
e.mu.Lock()
- e.state = StateError
+ e.setEndpointState(StateError)
e.HardError = err
// Lock released below.
epilogue()
-
return err
}
}
@@ -1215,7 +1246,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.keepalive.timer.init(&e.keepalive.waker)
defer e.keepalive.timer.cleanup()
- // Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
drained := e.drainDone != nil
e.mu.Unlock()
@@ -1224,8 +1254,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
<-e.undrain
}
- e.waiterQueue.Notify(waiter.EventOut)
-
// Set up the functions that will be called when the main protocol loop
// wakes up.
funcs := []struct {
@@ -1241,17 +1269,14 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
f: e.handleClose,
},
{
- w: &e.newSegmentWaker,
- f: e.handleSegments,
- },
- {
w: &closeWaker,
f: func() *tcpip.Error {
// This means the socket is being closed due
- // to the TCP_FIN_WAIT2 timeout was hit. Just
+ // to the TCP-FIN-WAIT2 timeout was hit. Just
// mark the socket as closed.
e.mu.Lock()
e.transitionToStateCloseLocked()
+ e.workerCleanup = true
e.mu.Unlock()
return nil
},
@@ -1267,6 +1292,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
},
},
{
+ w: &e.newSegmentWaker,
+ f: func() *tcpip.Error {
+ return e.handleSegments(false /* fastPath */)
+ },
+ },
+ {
w: &e.keepalive.waker,
f: e.keepaliveTimerExpired,
},
@@ -1293,14 +1324,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
if n&notifyReset != 0 {
- e.mu.Lock()
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
- e.mu.Unlock()
+ return tcpip.ErrConnectionAborted
+ }
+
+ if n&notifyResetByPeer != 0 {
+ return tcpip.ErrConnectionReset
}
if n&notifyClose != 0 && closeTimer == nil {
e.mu.Lock()
- if e.state == StateFinWait2 && e.closed {
+ if e.EndpointState() == StateFinWait2 && e.closed {
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() {
@@ -1320,11 +1353,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
if n&notifyDrain != 0 {
for !e.segmentQueue.empty() {
- if err := e.handleSegments(); err != nil {
+ if err := e.handleSegments(false /* fastPath */); err != nil {
return err
}
}
- if e.state != StateClose && e.state != StateError {
+ if e.EndpointState() != StateClose && e.EndpointState() != StateError {
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
@@ -1349,14 +1382,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
s.AddWaker(funcs[i].w, i)
}
+ // Notify the caller that the waker initialization is complete and the
+ // endpoint is ready.
+ if wakerInitDone != nil {
+ close(wakerInitDone)
+ }
+
+ // Tell waiters that the endpoint is connected and writable.
+ e.waiterQueue.Notify(waiter.EventOut)
+
// The following assertions and notifications are needed for restored
// endpoints. Fresh newly created endpoints have empty states and should
// not invoke any.
- e.segmentQueue.mu.Lock()
- if !e.segmentQueue.list.Empty() {
+ if !e.segmentQueue.empty() {
e.newSegmentWaker.Assert()
}
- e.segmentQueue.mu.Unlock()
e.rcvListMu.Lock()
if !e.rcvList.Empty() {
@@ -1371,28 +1411,53 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
+ cleanupOnError := func(err *tcpip.Error) {
+ e.mu.Lock()
+ e.workerCleanup = true
+ if err != nil {
+ e.resetConnectionLocked(err)
+ }
+ // Lock released below.
+ epilogue()
+ }
- for e.state != StateTimeWait && e.state != StateClose && e.state != StateError {
+loop:
+ for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError {
e.mu.Unlock()
e.workMu.Unlock()
v, _ := s.Fetch(true)
e.workMu.Lock()
- if err := funcs[v].f(); err != nil {
- e.mu.Lock()
- // Ensure we release all endpoint registration and route
- // references as the connection is now in an error
- // state.
- e.workerCleanup = true
- e.resetConnectionLocked(err)
- // Lock released below.
- epilogue()
+ // We need to double check here because the notification maybe
+ // stale by the time we got around to processing it.
+ //
+ // NOTE: since we now hold the workMu the processors cannot
+ // change the state of the endpoint so it's safe to proceed
+ // after this check.
+ switch e.EndpointState() {
+ case StateError:
+ // If the endpoint has already transitioned to an ERROR
+ // state just pass nil here as any reset that may need
+ // to be sent etc should already have been done and we
+ // just want to terminate the loop and cleanup the
+ // endpoint.
+ cleanupOnError(nil)
return nil
+ case StateTimeWait:
+ fallthrough
+ case StateClose:
+ e.mu.Lock()
+ break loop
+ default:
+ if err := funcs[v].f(); err != nil {
+ cleanupOnError(err)
+ return nil
+ }
+ e.mu.Lock()
}
- e.mu.Lock()
}
- state := e.state
+ state := e.EndpointState()
e.mu.Unlock()
var reuseTW func()
if state == StateTimeWait {
@@ -1405,13 +1470,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
s.Done()
// Wake up any waiters before we enter TIME_WAIT.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ e.mu.Lock()
+ e.workerCleanup = true
+ e.mu.Unlock()
reuseTW = e.doTimeWait()
}
// Mark endpoint as closed.
e.mu.Lock()
- if e.state != StateError {
- e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ if e.EndpointState() != StateError {
e.transitionToStateCloseLocked()
}
@@ -1468,7 +1535,11 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
tcpEP := listenEP.(*endpoint)
if EndpointState(tcpEP.State()) == StateListen {
reuseTW = func() {
- tcpEP.enqueueSegment(s)
+ if !tcpEP.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+ tcpEP.newSegmentWaker.Assert()
}
// We explicitly do not decRef
// the segment as it's still
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
new file mode 100644
index 000000000..e18012ac0
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -0,0 +1,224 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// epQueue is a queue of endpoints.
+type epQueue struct {
+ mu sync.Mutex
+ list endpointList
+}
+
+// enqueue adds e to the queue if the endpoint is not already on the queue.
+func (q *epQueue) enqueue(e *endpoint) {
+ q.mu.Lock()
+ if e.pendingProcessing {
+ q.mu.Unlock()
+ return
+ }
+ q.list.PushBack(e)
+ e.pendingProcessing = true
+ q.mu.Unlock()
+}
+
+// dequeue removes and returns the first element from the queue if available,
+// returns nil otherwise.
+func (q *epQueue) dequeue() *endpoint {
+ q.mu.Lock()
+ if e := q.list.Front(); e != nil {
+ q.list.Remove(e)
+ e.pendingProcessing = false
+ q.mu.Unlock()
+ return e
+ }
+ q.mu.Unlock()
+ return nil
+}
+
+// empty returns true if the queue is empty, false otherwise.
+func (q *epQueue) empty() bool {
+ q.mu.Lock()
+ v := q.list.Empty()
+ q.mu.Unlock()
+ return v
+}
+
+// processor is responsible for processing packets queued to a tcp endpoint.
+type processor struct {
+ epQ epQueue
+ newEndpointWaker sleep.Waker
+ id int
+}
+
+func newProcessor(id int) *processor {
+ p := &processor{
+ id: id,
+ }
+ go p.handleSegments()
+ return p
+}
+
+func (p *processor) queueEndpoint(ep *endpoint) {
+ // Queue an endpoint for processing by the processor goroutine.
+ p.epQ.enqueue(ep)
+ p.newEndpointWaker.Assert()
+}
+
+func (p *processor) handleSegments() {
+ const newEndpointWaker = 1
+ s := sleep.Sleeper{}
+ s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ defer s.Done()
+ for {
+ s.Fetch(true)
+ for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
+ if ep.segmentQueue.empty() {
+ continue
+ }
+
+ // If socket has transitioned out of connected state
+ // then just let the worker handle the packet.
+ //
+ // NOTE: We read this outside of e.mu lock which means
+ // that by the time we get to handleSegments the
+ // endpoint may not be in ESTABLISHED. But this should
+ // be fine as all normal shutdown states are handled by
+ // handleSegments and if the endpoint moves to a
+ // CLOSED/ERROR state then handleSegments is a noop.
+ if ep.EndpointState() != StateEstablished {
+ ep.newSegmentWaker.Assert()
+ continue
+ }
+
+ if !ep.workMu.TryLock() {
+ ep.newSegmentWaker.Assert()
+ continue
+ }
+ // If the endpoint is in a connected state then we do
+ // direct delivery to ensure low latency and avoid
+ // scheduler interactions.
+ if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose {
+ // Send any active resets if required.
+ if err != nil {
+ ep.mu.Lock()
+ ep.resetConnectionLocked(err)
+ ep.mu.Unlock()
+ }
+ ep.notifyProtocolGoroutine(notifyTickleWorker)
+ ep.workMu.Unlock()
+ continue
+ }
+
+ if !ep.segmentQueue.empty() {
+ p.epQ.enqueue(ep)
+ }
+
+ ep.workMu.Unlock()
+ }
+ }
+}
+
+// dispatcher manages a pool of TCP endpoint processors which are responsible
+// for the processing of inbound segments. This fixed pool of processor
+// goroutines do full tcp processing. The processor is selected based on the
+// hash of the endpoint id to ensure that delivery for the same endpoint happens
+// in-order.
+type dispatcher struct {
+ processors []*processor
+ seed uint32
+}
+
+func newDispatcher(nProcessors int) *dispatcher {
+ processors := []*processor{}
+ for i := 0; i < nProcessors; i++ {
+ processors = append(processors, newProcessor(i))
+ }
+ return &dispatcher{
+ processors: processors,
+ seed: generateRandUint32(),
+ }
+}
+
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+ ep := stackEP.(*endpoint)
+ s := newSegment(r, id, pkt)
+ if !s.parse() {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ s.decRef()
+ return
+ }
+
+ if !s.csumValid {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.ChecksumErrors.Increment()
+ ep.stats.ReceiveErrors.ChecksumErrors.Increment()
+ s.decRef()
+ return
+ }
+
+ ep.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ ep.stats.SegmentsReceived.Increment()
+ if (s.flags & header.TCPFlagRst) != 0 {
+ ep.stack.Stats().TCP.ResetsReceived.Increment()
+ }
+
+ if !ep.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+
+ // For sockets not in established state let the worker goroutine
+ // handle the packets.
+ if ep.EndpointState() != StateEstablished {
+ ep.newSegmentWaker.Assert()
+ return
+ }
+
+ d.selectProcessor(id).queueEndpoint(ep)
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+}
+
+func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
+ payload := []byte{
+ byte(id.LocalPort),
+ byte(id.LocalPort >> 8),
+ byte(id.RemotePort),
+ byte(id.RemotePort >> 8)}
+
+ h := jenkins.Sum32(d.seed)
+ h.Write(payload)
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+
+ return d.processors[h.Sum32()%uint32(len(d.processors))]
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index cc8b533c8..13718ff55 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -120,6 +120,7 @@ const (
notifyMTUChanged
notifyDrain
notifyReset
+ notifyResetByPeer
notifyKeepaliveChanged
notifyMSSChanged
// notifyTickleWorker is used to tickle the protocol main loop during a
@@ -127,6 +128,7 @@ const (
// ensures the loop terminates if the final state of the endpoint is
// say TIME_WAIT.
notifyTickleWorker
+ notifyError
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -283,6 +285,18 @@ func (*EndpointInfo) IsEndpointInfo() {}
type endpoint struct {
EndpointInfo
+ // endpointEntry is used to queue endpoints for processing to the
+ // a given tcp processor goroutine.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field..
+ endpointEntry `state:"nosave"`
+
+ // pendingProcessing is true if this endpoint is queued for processing
+ // to a TCP processor.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field..
+ pendingProcessing bool `state:"nosave"`
+
// workMu is used to arbitrate which goroutine may perform protocol
// work. Only the main protocol goroutine is expected to call Lock() on
// it, but other goroutines (e.g., send) may call TryLock() to eagerly
@@ -324,6 +338,7 @@ type endpoint struct {
// The following fields are protected by the mutex.
mu sync.RWMutex `state:"nosave"`
+ // state must be read/set using the EndpointState()/setEndpointState() methods.
state EndpointState `state:".(EndpointState)"`
// origEndpointState is only used during a restore phase to save the
@@ -359,7 +374,7 @@ type endpoint struct {
workerRunning bool
// workerCleanup specifies if the worker goroutine must perform cleanup
- // before exitting. This can only be set to true when workerRunning is
+ // before exiting. This can only be set to true when workerRunning is
// also true, and they're both protected by the mutex.
workerCleanup bool
@@ -371,6 +386,8 @@ type endpoint struct {
// recentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint.
+ //
+ // recentTS must be read/written atomically.
recentTS uint32
// tsOffset is a randomized offset added to the value of the
@@ -567,6 +584,48 @@ func (e *endpoint) ResumeWork() {
e.workMu.Unlock()
}
+// setEndpointState updates the state of the endpoint to state atomically. This
+// method is unexported as the only place we should update the state is in this
+// package but we allow the state to be read freely without holding e.mu.
+//
+// Precondition: e.mu must be held to call this method.
+func (e *endpoint) setEndpointState(state EndpointState) {
+ oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ switch state {
+ case StateEstablished:
+ e.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.stack.Stats().TCP.CurrentConnected.Increment()
+ case StateError:
+ fallthrough
+ case StateClose:
+ if oldstate == StateCloseWait || oldstate == StateEstablished {
+ e.stack.Stats().TCP.EstablishedResets.Increment()
+ }
+ fallthrough
+ default:
+ if oldstate == StateEstablished {
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ }
+ }
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+}
+
+// EndpointState returns the current state of the endpoint.
+func (e *endpoint) EndpointState() EndpointState {
+ return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+}
+
+// setRecentTimestamp atomically sets the recentTS field to the
+// provided value.
+func (e *endpoint) setRecentTimestamp(recentTS uint32) {
+ atomic.StoreUint32(&e.recentTS, recentTS)
+}
+
+// recentTimestamp atomically reads and returns the value of the recentTS field.
+func (e *endpoint) recentTimestamp() uint32 {
+ return atomic.LoadUint32(&e.recentTS)
+}
+
// keepalive is a synchronization wrapper used to appease stateify. See the
// comment in endpoint, where it is used.
//
@@ -656,7 +715,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
e.mu.RLock()
defer e.mu.RUnlock()
- switch e.state {
+ switch e.EndpointState() {
case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
@@ -672,7 +731,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
}
}
}
- if e.state.connected() {
+ if e.EndpointState().connected() {
// Determine if the endpoint is writable if requested.
if (mask & waiter.EventOut) != 0 {
e.sndBufMu.Lock()
@@ -733,14 +792,20 @@ func (e *endpoint) Close() {
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ e.closeNoShutdown()
+}
+// closeNoShutdown closes the endpoint without doing a full shutdown. This is
+// used when a connection needs to be aborted with a RST and we want to skip
+// a full 4 way TCP shutdown.
+func (e *endpoint) closeNoShutdown() {
e.mu.Lock()
// For listening sockets, we always release ports inline so that they
// are immediately available for reuse after Close() is called. If also
// registered, we unregister as well otherwise the next user would fail
// in Listen() when trying to register.
- if e.state == StateListen && e.isPortReserved {
+ if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.isRegistered = false
@@ -780,6 +845,8 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
defer close(done)
for n := range e.acceptedChan {
n.notifyProtocolGoroutine(notifyReset)
+ // close all connections that have completed but
+ // not accepted by the application.
n.Close()
}
}()
@@ -797,11 +864,13 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
// after Close() is called and the worker goroutine (if any) is done with its
// work.
func (e *endpoint) cleanupLocked() {
+
// Close all endpoints that might have been accepted by TCP but not by
// the client.
if e.acceptedChan != nil {
e.closePendingAcceptableConnectionsLocked()
}
+
e.workerCleanup = false
if e.isRegistered {
@@ -920,7 +989,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
bufUsed := e.rcvBufUsed
- if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 {
+ if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
e.mu.RUnlock()
@@ -944,7 +1013,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.state.connected() {
+ if e.rcvClosed || !e.EndpointState().connected() {
return buffer.View{}, tcpip.ErrClosedForReceive
}
return buffer.View{}, tcpip.ErrWouldBlock
@@ -980,8 +1049,8 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
- if !e.state.connected() {
- switch e.state {
+ if !e.EndpointState().connected() {
+ switch e.EndpointState() {
case StateError:
return 0, e.HardError
default:
@@ -1039,42 +1108,86 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, perr
}
- if !opts.Atomic { // See above.
- e.mu.RLock()
- e.sndBufMu.Lock()
+ if opts.Atomic {
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
+ e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
+ }
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
+ if e.workMu.TryLock() {
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
e.mu.RUnlock()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
- }
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
}
- }
-
- // Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
-
- if e.workMu.TryLock() {
// Do the work inline.
e.handleWrite()
e.workMu.Unlock()
} else {
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
+ e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
+
+ }
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
@@ -1091,7 +1204,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data.
- if s := e.state; !s.connected() && s != StateClose {
+ if s := e.EndpointState(); !s.connected() && s != StateClose {
if s == StateError {
return 0, tcpip.ControlMessages{}, e.HardError
}
@@ -1103,7 +1216,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
defer e.rcvListMu.Unlock()
if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.state.connected() {
+ if e.rcvClosed || !e.EndpointState().connected() {
e.stats.ReadErrors.ReadClosed.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
@@ -1187,7 +1300,7 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
defer e.mu.Unlock()
// We only allow this to be set when we're in the initial state.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -1402,14 +1515,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// Acquire the work mutex as we may need to
// reinitialize the congestion control state.
e.mu.Lock()
- state := e.state
+ state := e.EndpointState()
e.cc = v
e.mu.Unlock()
switch state {
case StateEstablished:
e.workMu.Lock()
e.mu.Lock()
- if e.state == state {
+ if e.EndpointState() == state {
e.snd.cc = e.snd.initCongestionControl(e.cc)
}
e.mu.Unlock()
@@ -1472,7 +1585,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
defer e.mu.RUnlock()
// The endpoint cannot be in listen state.
- if e.state == StateListen {
+ if e.EndpointState() == StateListen {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -1731,7 +1844,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return err
}
- if e.state.connected() {
+ if e.EndpointState().connected() {
// The endpoint is already connected. If caller hasn't been
// notified yet, return success.
if !e.isConnectNotified {
@@ -1743,7 +1856,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
nicID := addr.NIC
- switch e.state {
+ switch e.EndpointState() {
case StateBound:
// If we're already bound to a NIC but the caller is requesting
// that we use a different one now, we cannot proceed.
@@ -1850,7 +1963,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
e.isRegistered = true
- e.state = StateConnecting
+ e.setEndpointState(StateConnecting)
e.route = r.Clone()
e.boundNICID = nicID
e.effectiveNetProtos = netProtos
@@ -1871,14 +1984,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
- e.state = StateEstablished
- e.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.setEndpointState(StateEstablished)
}
if run {
e.workerRunning = true
e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
- go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save.
+ go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
}
return tcpip.ErrConnectStarted
@@ -1896,7 +2008,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.shutdownFlags |= flags
finQueued := false
switch {
- case e.state.connected():
+ case e.EndpointState().connected():
// Close for read.
if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
// Mark read side as closed.
@@ -1908,8 +2020,18 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
- e.notifyProtocolGoroutine(notifyReset)
e.mu.Unlock()
+ // Try to send an active reset immediately if the
+ // work mutex is available.
+ if e.workMu.TryLock() {
+ e.mu.Lock()
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ e.mu.Unlock()
+ e.workMu.Unlock()
+ } else {
+ e.notifyProtocolGoroutine(notifyReset)
+ }
return nil
}
}
@@ -1931,11 +2053,10 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
finQueued = true
// Mark endpoint as closed.
e.sndClosed = true
-
e.sndBufMu.Unlock()
}
- case e.state == StateListen:
+ case e.EndpointState() == StateListen:
// Tell protocolListenLoop to stop.
if flags&tcpip.ShutdownRead != 0 {
e.notifyProtocolGoroutine(notifyClose)
@@ -1976,7 +2097,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// When the endpoint shuts down, it sets workerCleanup to true, and from
// that point onward, acceptedChan is the responsibility of the cleanup()
// method (and should not be touched anywhere else, including here).
- if e.state == StateListen && !e.workerCleanup {
+ if e.EndpointState() == StateListen && !e.workerCleanup {
// Adjust the size of the channel iff we can fix existing
// pending connections into the new one.
if len(e.acceptedChan) > backlog {
@@ -1994,7 +2115,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
return nil
}
- if e.state == StateInitial {
+ if e.EndpointState() == StateInitial {
// The listen is called on an unbound socket, the socket is
// automatically bound to a random free port with the local
// address set to INADDR_ANY.
@@ -2004,7 +2125,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
// Endpoint must be bound before it can transition to listen mode.
- if e.state != StateBound {
+ if e.EndpointState() != StateBound {
e.stats.ReadErrors.InvalidEndpointState.Increment()
return tcpip.ErrInvalidEndpointState
}
@@ -2015,24 +2136,27 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
e.isRegistered = true
- e.state = StateListen
+ e.setEndpointState(StateListen)
+
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
e.workerRunning = true
-
go e.protocolListenLoop( // S/R-SAFE: drained on save.
seqnum.Size(e.receiveBufferAvailable()))
-
return nil
}
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+ e.mu.Lock()
e.waiterQueue = waiterQueue
e.workerRunning = true
- go e.protocolMainLoop(false) // S/R-SAFE: drained on save.
+ e.mu.Unlock()
+ wakerInitDone := make(chan struct{})
+ go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save.
+ <-wakerInitDone
}
// Accept returns a new endpoint if a peer has established a connection
@@ -2042,7 +2166,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
defer e.mu.RUnlock()
// Endpoint must be in listen state before it can accept connections.
- if e.state != StateListen {
+ if e.EndpointState() != StateListen {
return nil, nil, tcpip.ErrInvalidEndpointState
}
@@ -2069,7 +2193,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// Don't allow binding once endpoint is not in the initial state
// anymore. This is because once the endpoint goes into a connected or
// listen state, it is already bound.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrAlreadyBound
}
@@ -2131,7 +2255,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
// Mark endpoint as bound.
- e.state = StateBound
+ e.setEndpointState(StateBound)
return nil
}
@@ -2153,7 +2277,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if !e.state.connected() {
+ if !e.EndpointState().connected() {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -2164,45 +2288,22 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}, nil
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
- s := newSegment(r, id, pkt)
- if !s.parse() {
- e.stack.Stats().MalformedRcvdPackets.Increment()
- e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
- e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
- s.decRef()
- return
- }
-
- if !s.csumValid {
- e.stack.Stats().MalformedRcvdPackets.Increment()
- e.stack.Stats().TCP.ChecksumErrors.Increment()
- e.stats.ReceiveErrors.ChecksumErrors.Increment()
- s.decRef()
- return
- }
-
- e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
- e.stats.SegmentsReceived.Increment()
- if (s.flags & header.TCPFlagRst) != 0 {
- e.stack.Stats().TCP.ResetsReceived.Increment()
- }
-
- e.enqueueSegment(s)
+ // TCP HandlePacket is not required anymore as inbound packets first
+ // land at the Dispatcher which then can either delivery using the
+ // worker go routine or directly do the invoke the tcp processing inline
+ // based on the state of the endpoint.
}
-func (e *endpoint) enqueueSegment(s *segment) {
+func (e *endpoint) enqueueSegment(s *segment) bool {
// Send packet to worker goroutine.
- if e.segmentQueue.enqueue(s) {
- e.newSegmentWaker.Assert()
- } else {
+ if !e.segmentQueue.enqueue(s) {
// The queue is full, so we drop the segment.
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.SegmentQueueDropped.Increment()
- s.decRef()
+ return false
}
+ return true
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
@@ -2319,8 +2420,8 @@ func (e *endpoint) rcvWndScaleForHandshake() int {
// updateRecentTimestamp updates the recent timestamp using the algorithm
// described in https://tools.ietf.org/html/rfc7323#section-4.3
func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
- if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
- e.recentTS = tsVal
+ if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
+ e.setRecentTimestamp(tsVal)
}
}
@@ -2330,7 +2431,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value,
func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
if synOpts.TS {
e.sendTSOk = true
- e.recentTS = synOpts.TSVal
+ e.setRecentTimestamp(synOpts.TSVal)
}
}
@@ -2419,7 +2520,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
// Endpoint TCP Option state.
s.SendTSOk = e.sendTSOk
- s.RecentTS = e.recentTS
+ s.RecentTS = e.recentTimestamp()
s.TSOffset = e.tsOffset
s.SACKPermitted = e.sackPermitted
s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
@@ -2526,9 +2627,7 @@ func (e *endpoint) initGSO() {
// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
// state for diagnostics.
func (e *endpoint) State() uint32 {
- e.mu.Lock()
- defer e.mu.Unlock()
- return uint32(e.state)
+ return uint32(e.EndpointState())
}
// Info returns a copy of the endpoint info.
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 4b8d867bc..4a46f0ec5 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -16,6 +16,7 @@ package tcp
import (
"fmt"
+ "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sync"
@@ -48,7 +49,7 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
defer e.mu.Unlock()
- switch e.state {
+ switch e.EndpointState() {
case StateInitial, StateBound:
// TODO(b/138137272): this enumeration duplicates
// EndpointState.connected. remove it.
@@ -70,31 +71,30 @@ func (e *endpoint) beforeSave() {
fallthrough
case StateListen, StateConnecting:
e.drainSegmentLocked()
- if e.state != StateClose && e.state != StateError {
+ if e.EndpointState() != StateClose && e.EndpointState() != StateError {
if !e.workerRunning {
panic("endpoint has no worker running in listen, connecting, or connected state")
}
break
}
- fallthrough
case StateError, StateClose:
- for (e.state == StateError || e.state == StateClose) && e.workerRunning {
+ for e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
e.mu.Lock()
}
if e.workerRunning {
- panic("endpoint still has worker running in closed or error state")
+ panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID))
}
default:
- panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
+ panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState()))
}
if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
panic("endpoint still has waiters upon save")
}
- if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) {
+ if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) {
panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
}
}
@@ -135,7 +135,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
// saveState is invoked by stateify.
func (e *endpoint) saveState() EndpointState {
- return e.state
+ return e.EndpointState()
}
// Endpoint loading must be done in the following ordering by their state, to
@@ -151,7 +151,8 @@ var connectingLoading sync.WaitGroup
func (e *endpoint) loadState(state EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
- if state.connected() {
+ // For restore purposes we treat TimeWait like a connected endpoint.
+ if state.connected() || state == StateTimeWait {
connectedLoading.Add(1)
}
switch state {
@@ -160,13 +161,14 @@ func (e *endpoint) loadState(state EndpointState) {
case StateConnecting, StateSynSent, StateSynRecv:
connectingLoading.Add(1)
}
- e.state = state
+ // Directly update the state here rather than using e.setEndpointState
+ // as the endpoint is still being loaded and the stack reference to increment
+ // metrics is not yet initialized.
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
}
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- // Freeze segment queue before registering to prevent any segments
- // from being delivered while it is being restored.
e.origEndpointState = e.state
// Restore the endpoint to InitialState as it will be moved to
// its origEndpointState during Resume.
@@ -180,7 +182,6 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.workMu.Init()
state := e.origEndpointState
-
switch state {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
@@ -276,7 +277,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
listenLoading.Wait()
connectingLoading.Wait()
bind()
- e.state = StateClose
+ e.setEndpointState(StateClose)
tcpip.AsyncLoading.Done()
}()
}
@@ -288,6 +289,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
+
}
// saveLastError is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 9a8f64aa6..958c06fa7 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -21,6 +21,7 @@
package tcp
import (
+ "runtime"
"strings"
"time"
@@ -104,6 +105,7 @@ type protocol struct {
moderateReceiveBuffer bool
tcpLingerTimeout time.Duration
tcpTimeWaitTimeout time.Duration
+ dispatcher *dispatcher
}
// Number returns the tcp protocol number.
@@ -134,6 +136,14 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
return h.SourcePort(), h.DestinationPort(), nil
}
+// QueuePacket queues packets targeted at an endpoint after hashing the packet
+// to a specific processing queue. Each queue is serviced by its own processor
+// goroutine which is responsible for dequeuing and doing full TCP dispatch of
+// the packet.
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+ p.dispatcher.queuePacket(r, ep, id, pkt)
+}
+
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
//
@@ -330,5 +340,6 @@ func NewProtocol() stack.TransportProtocol {
availableCongestionControl: []string{ccReno, ccCubic},
tcpLingerTimeout: DefaultTCPLingerTimeout,
tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 05c8488f8..958f03ac1 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -169,19 +169,19 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// We just received a FIN, our next state depends on whether we sent a
// FIN already or not.
r.ep.mu.Lock()
- switch r.ep.state {
+ switch r.ep.EndpointState() {
case StateEstablished:
- r.ep.state = StateCloseWait
+ r.ep.setEndpointState(StateCloseWait)
case StateFinWait1:
if s.flagIsSet(header.TCPFlagAck) {
// FIN-ACK, transition to TIME-WAIT.
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
} else {
// Simultaneous close, expecting a final ACK.
- r.ep.state = StateClosing
+ r.ep.setEndpointState(StateClosing)
}
case StateFinWait2:
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
}
r.ep.mu.Unlock()
@@ -205,16 +205,16 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// shutdown states.
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
r.ep.mu.Lock()
- switch r.ep.state {
+ switch r.ep.EndpointState() {
case StateFinWait1:
- r.ep.state = StateFinWait2
+ r.ep.setEndpointState(StateFinWait2)
// Notify protocol goroutine that we have received an
// ACK to our FIN so that it can start the FIN_WAIT2
// timer to abort connection if the other side does
// not close within 2MSL.
r.ep.notifyProtocolGoroutine(notifyClose)
case StateClosing:
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
case StateLastAck:
r.ep.transitionToStateCloseLocked()
}
@@ -267,7 +267,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
switch state {
case StateCloseWait, StateClosing, StateLastAck:
if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
- s.decRef()
// Just drop the segment as we have
// already received a FIN and this
// segment is after the sequence number
@@ -284,7 +283,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// trigger a RST.
endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
if rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
- s.decRef()
return true, tcpip.ErrConnectionAborted
}
if state == StateFinWait1 {
@@ -314,7 +312,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// the last actual data octet in a segment in
// which it occurs.
if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
- s.decRef()
return true, tcpip.ErrConnectionAborted
}
}
@@ -336,7 +333,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// r as they arrive. It is called by the protocol main loop.
func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
r.ep.mu.RLock()
- state := r.ep.state
+ state := r.ep.EndpointState()
closed := r.ep.closed
r.ep.mu.RUnlock()
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index fdff7ed81..b74b61e7d 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -705,17 +705,15 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
}
seg.flags = header.TCPFlagAck | header.TCPFlagFin
segEnd = seg.sequenceNumber.Add(1)
- // Transition to FIN-WAIT1 state since we're initiating an active close.
- s.ep.mu.Lock()
- switch s.ep.state {
+ // Update the state to reflect that we have now
+ // queued a FIN.
+ switch s.ep.EndpointState() {
case StateCloseWait:
- // We've already received a FIN and are now sending our own. The
- // sender is now awaiting a final ACK for this FIN.
- s.ep.state = StateLastAck
+ s.ep.setEndpointState(StateLastAck)
default:
- s.ep.state = StateFinWait1
+ s.ep.setEndpointState(StateFinWait1)
}
- s.ep.mu.Unlock()
+
} else {
// We're sending a non-FIN segment.
if seg.flags&header.TCPFlagFin != 0 {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 6edfa8dce..df2fb1071 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -293,7 +293,6 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.SeqNum(uint32(c.IRS+1)),
checker.AckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
finHeaders := &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -459,6 +458,9 @@ func TestConnectResetAfterClose(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
+ // RST is always generated with sndNxt which if the FIN
+ // has been sent will be 1 higher than the sequence number
+ // of the FIN itself.
checker.SeqNum(uint32(c.IRS)+2),
checker.AckNum(0),
checker.TCPFlags(header.TCPFlagRst),
@@ -468,6 +470,89 @@ func TestConnectResetAfterClose(t *testing.T) {
}
}
+// TestCurrentConnectedIncrement tests increment of the current
+// established and connected counters.
+func TestCurrentConnectedIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
+ // after 1 second in TIME_WAIT state.
+ tcpTimeWaitTimeout := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 1", got)
+ }
+ gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value()
+ if gotConnected != 1 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 1", gotConnected)
+ }
+
+ ep.Close()
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = %v", got, gotConnected)
+ }
+
+ // Ack and send FIN as well.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Check that the stack acks the FIN.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Wait for the TIME-WAIT state to transition to CLOSED.
+ time.Sleep(1 * time.Second)
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 0", got)
+ }
+}
+
// TestClosingWithEnqueuedSegments tests handling of still enqueued segments
// when the endpoint transitions to StateClose. The in-flight segments would be
// re-enqueued to a any listening endpoint.
@@ -1500,6 +1585,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ // RST is always generated with sndNxt which if the FIN
+ // has been sent will be 1 higher than the sequence
+ // number of the FIN itself.
checker.SeqNum(uint32(c.IRS)+2),
))
// The RST puts the endpoint into an error state.
@@ -5441,6 +5529,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
packetsSent++
}
+
// Resume the worker so that it only sees the packets once all of them
// are waiting to be read.
worker.ResumeWork()
@@ -5508,7 +5597,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
stk := c.Stack()
// Set lower limits for auto-tuning tests. This is required because the
// test stops the worker which can cause packets to be dropped because
- // the segment queue holding unprocessed packets is limited to 500.
+ // the segment queue holding unprocessed packets is limited to 300.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
@@ -5563,6 +5652,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
totalSent += mss
packetsSent++
}
+
// Resume it so that it only sees the packets once all of them
// are waiting to be read.
worker.ResumeWork()
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
index b33ec2087..ce6a2c31d 100644
--- a/pkg/tcpip/transport/tcp/testing/context/BUILD
+++ b/pkg/tcpip/transport/tcp/testing/context/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,7 +6,6 @@ go_library(
name = "context",
testonly = 1,
srcs = ["context.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context",
visibility = [
"//visibility:public",
],
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 822907998..730ac4292 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -18,6 +18,7 @@ package context
import (
"bytes"
+ "context"
"testing"
"time"
@@ -215,11 +216,9 @@ func (c *Context) Stack() *stack.Stack {
func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
c.t.Helper()
- select {
- case <-c.linkEP.C:
+ ctx, _ := context.WithTimeout(context.Background(), wait)
+ if _, ok := c.linkEP.ReadContext(ctx); ok {
c.t.Fatal(errMsg)
-
- case <-time.After(wait):
}
}
@@ -234,27 +233,27 @@ func (c *Context) CheckNoPacket(errMsg string) {
// 2 seconds.
func (c *Context) GetPacket() []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
- if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
- c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
- }
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
- case <-time.After(2 * time.Second):
- c.t.Fatalf("Packet wasn't written out")
+ if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
+ c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
}
- return nil
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
}
// GetPacketNonBlocking reads a packet from the link layer endpoint
@@ -263,20 +262,21 @@ func (c *Context) GetPacket() []byte {
// nil immediately.
func (c *Context) GetPacketNonBlocking() []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
-
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
- default:
+ p, ok := c.linkEP.Read()
+ if !ok {
return nil
}
+
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
}
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
@@ -484,23 +484,23 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
func (c *Context) GetV6Packet() []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv6.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
- }
- b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
- copy(b, p.Pkt.Header.View())
- copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
-
- checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
- return b
- case <-time.After(2 * time.Second):
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
+
+ if p.Proto != ipv6.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
}
+ b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
+ copy(b, p.Pkt.Header.View())
+ copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
- return nil
+ checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
+ return b
}
// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 43fcc27f0..3ad6994a7 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "tcpconntrack",
srcs = ["tcp_conntrack.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 57ff123e3..adc908e24 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -25,7 +24,6 @@ go_library(
"protocol.go",
"udp_packet_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/udp",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 13446f5d9..c9cbed8f4 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -31,6 +31,7 @@ type udpPacket struct {
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
+ tos uint8
}
// EndpointState represents the state of a UDP endpoint.
@@ -113,6 +114,10 @@ type endpoint struct {
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
+ // receiveTOS determines if the incoming IPv4 TOS header field is passed
+ // as ancillary data to ControlMessages on Read.
+ receiveTOS bool
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -243,7 +248,18 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
*addr = p.senderAddress
}
- return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+ cm := tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ }
+ e.mu.RLock()
+ receiveTOS := e.receiveTOS
+ e.mu.RUnlock()
+ if receiveTOS {
+ cm.HasTOS = true
+ cm.TOS = p.tos
+ }
+ return p.data.ToView(), cm, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -458,6 +474,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
switch opt {
+ case tcpip.ReceiveTOSOption:
+ e.mu.Lock()
+ e.receiveTOS = v
+ e.mu.Unlock()
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -664,15 +686,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
+ case tcpip.ReceiveTOSOption:
+ e.mu.RLock()
+ v := e.receiveTOS
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
return false, tcpip.ErrUnknownProtocolOption
}
- e.mu.Lock()
+ e.mu.RLock()
v := e.v6only
- e.mu.Unlock()
+ e.mu.RUnlock()
return v, nil
}
@@ -1215,6 +1243,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
e.rcvList.PushBack(packet)
e.rcvBufSize += pkt.Data.Size()
+ // Save any useful information from the network header to the packet.
+ switch r.NetProto {
+ case header.IPv4ProtocolNumber:
+ packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
+ }
+
packet.timestamp = e.stack.NowNanoseconds()
e.rcvMu.Unlock()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 0a82bc4fa..f0ff3fe71 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,6 +16,7 @@ package udp_test
import (
"bytes"
+ "context"
"fmt"
"math/rand"
"testing"
@@ -56,6 +57,7 @@ const (
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
broadcastAddr = header.IPv4Broadcast
+ testTOS = 0x80
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -273,11 +275,16 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
-
- s := stack.New(stack.Options{
+ return newDualTestContextWithOptions(t, mtu, stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
})
+}
+
+func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext {
+ t.Helper()
+
+ s := stack.New(options)
ep := channel.New(256, mtu, "")
wep := stack.LinkEndpoint(ep)
@@ -351,30 +358,29 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != flow.netProto() {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
- }
-
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
-
- h := flow.header4Tuple(outgoing)
- checkers := append(
- checkers,
- checker.SrcAddr(h.srcAddr.Addr),
- checker.DstAddr(h.dstAddr.Addr),
- checker.UDP(checker.DstPort(h.dstAddr.Port)),
- )
- flow.checkerFn()(c.t, b, checkers...)
- return b
-
- case <-time.After(2 * time.Second):
+ ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
c.t.Fatalf("Packet wasn't written out")
+ return nil
}
- return nil
+ if p.Proto != flow.netProto() {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
+ }
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+
+ h := flow.header4Tuple(outgoing)
+ checkers = append(
+ checkers,
+ checker.SrcAddr(h.srcAddr.Addr),
+ checker.DstAddr(h.dstAddr.Addr),
+ checker.UDP(checker.DstPort(h.dstAddr.Port)),
+ )
+ flow.checkerFn()(c.t, b, checkers...)
+ return b
}
// injectPacket creates a packet of the given flow and with the given payload,
@@ -453,6 +459,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
ip := header.IPv4(buf)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
+ TOS: testTOS,
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
@@ -552,8 +559,8 @@ func TestBindToDeviceOption(t *testing.T) {
// testReadInternal sends a packet of the given test flow into the stack by
// injecting it into the link endpoint. It then attempts to read it from the
// UDP endpoint and depending on if this was expected to succeed verifies its
-// correctness.
-func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) {
+// correctness including any additional checker functions provided.
+func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
c.t.Helper()
payload := newPayload()
@@ -568,12 +575,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
var addr tcpip.FullAddress
- v, _, err := c.ep.Read(&addr)
+ v, cm, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, _, err = c.ep.Read(&addr)
+ v, cm, err = c.ep.Read(&addr)
case <-time.After(300 * time.Millisecond):
if packetShouldBeDropped {
@@ -606,15 +613,21 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
if !bytes.Equal(payload, v) {
c.t.Fatalf("bad payload: got %x, want %x", v, payload)
}
+
+ // Run any checkers against the ControlMessages.
+ for _, f := range checkers {
+ f(c.t, cm)
+ }
+
c.checkEndpointReadStats(1, epstats, err)
}
// testRead sends a packet of the given test flow into the stack by injecting it
// into the link endpoint. It then reads it from the UDP endpoint and verifies
-// its correctness.
-func testRead(c *testContext, flow testFlow) {
+// its correctness including any additional checker functions provided.
+func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
c.t.Helper()
- testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */)
+ testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
}
// testFailingRead sends a packet of the given test flow into the stack by
@@ -755,6 +768,49 @@ func TestV6ReadOnV6(t *testing.T) {
testRead(c, unicastV6)
}
+// TestV4ReadSelfSource checks that packets coming from a local IP address are
+// correctly dropped when handleLocal is true and not otherwise.
+func TestV4ReadSelfSource(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ handleLocal bool
+ wantErr *tcpip.Error
+ wantInvalidSource uint64
+ }{
+ {"HandleLocal", false, nil, 0},
+ {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1},
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ HandleLocal: tt.handleLocal,
+ })
+ defer c.cleanup()
+
+ c.createEndpointForFlow(unicastV4)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ h.srcAddr = h.dstAddr
+
+ c.injectV4Packet(payload, &h, true /* valid */)
+
+ if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
+ t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
+ }
+
+ if _, _, err := c.ep.Read(nil); err != tt.wantErr {
+ t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
func TestV4ReadOnV4(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1228,7 +1284,10 @@ func TestTTL(t *testing.T) {
} else {
p = ipv6.NewProtocol()
}
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ }))
if err != nil {
t.Fatal(err)
}
@@ -1261,7 +1320,10 @@ func TestSetTTL(t *testing.T) {
} else {
p = ipv6.NewProtocol()
}
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ }))
if err != nil {
t.Fatal(err)
}
@@ -1282,7 +1344,7 @@ func TestTOSV4(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = 0xC0
+ const tos = testTOS
var v tcpip.IPv4TOSOption
if err := c.ep.GetSockOpt(&v); err != nil {
c.t.Errorf("GetSockopt failed: %s", err)
@@ -1317,7 +1379,7 @@ func TestTOSV6(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = 0xC0
+ const tos = testTOS
var v tcpip.IPv6TrafficClassOption
if err := c.ep.GetSockOpt(&v); err != nil {
c.t.Errorf("GetSockopt failed: %s", err)
@@ -1344,6 +1406,47 @@ func TestTOSV6(t *testing.T) {
}
}
+func TestReceiveTOSV4(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, broadcast} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Verify that setting and reading the option works.
+ v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
+ }
+ // Test for expected default value.
+ if v != false {
+ c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false)
+ }
+
+ want := true
+ if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil {
+ c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err)
+ }
+
+ got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
+ }
+ if got != want {
+ c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want)
+ }
+
+ // Verify that the correct received TOS is handed through as
+ // ancillary data to the ControlMessages struct.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+ testRead(c, flow, checker.ReceiveTOS(testTOS))
+ })
+ }
+}
+
func TestMulticastInterfaceOption(t *testing.T) {
for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1438,48 +1541,50 @@ func TestV4UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- select {
- case p := <-c.linkEP.C:
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
- case <-time.After(1 * time.Second):
- return
}
+ return
}
- select {
- case p := <-c.linkEP.C:
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
- if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
+ // ICMP required.
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ t.Fatalf("packet wasn't written out")
+ return
+ }
+
+ var pkt []byte
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
+ if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
- hdr := header.IPv4(pkt)
- checker.IPv4(t, hdr, checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ hdr := header.IPv4(pkt)
+ checker.IPv4(t, hdr, checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
- icmpPkt := header.ICMPv4(hdr.Payload())
- payloadIPHeader := header.IPv4(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
- }
+ icmpPkt := header.ICMPv4(hdr.Payload())
+ payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ }
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("packet wasn't written out")
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
}
})
}
@@ -1512,47 +1617,49 @@ func TestV6UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- select {
- case p := <-c.linkEP.C:
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
- case <-time.After(1 * time.Second):
- return
}
+ return
}
- select {
- case p := <-c.linkEP.C:
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
- if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
+ // ICMP required.
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ t.Fatalf("packet wasn't written out")
+ return
+ }
- hdr := header.IPv6(pkt)
- checker.IPv6(t, hdr, checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+ var pkt []byte
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
+ if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
- icmpPkt := header.ICMPv6(hdr.Payload())
- payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
- }
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+ hdr := header.IPv6(pkt)
+ checker.IPv6(t, hdr, checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("packet wasn't written out")
+ icmpPkt := header.ICMPv6(hdr.Payload())
+ payloadIPHeader := header.IPv6(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ }
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
}
})
}