summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go44
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go65
-rw-r--r--pkg/tcpip/header/ipv4.go9
-rw-r--r--pkg/tcpip/header/ipv6.go9
-rw-r--r--pkg/tcpip/link/rawfile/BUILD3
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_arm64.s2
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go (renamed from pkg/tcpip/link/rawfile/blockingpoll_stub_unsafe.go)0
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_unsafe.go45
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go66
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go5
-rw-r--r--pkg/tcpip/network/arp/arp_test.go4
-rw-r--r--pkg/tcpip/network/ip_test.go6
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go28
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go33
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD1
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go5
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go9
-rw-r--r--pkg/tcpip/stack/nic.go280
-rw-r--r--pkg/tcpip/stack/route.go10
-rw-r--r--pkg/tcpip/stack/stack.go2
-rw-r--r--pkg/tcpip/stack/stack_test.go661
-rw-r--r--pkg/tcpip/stack/transport_test.go53
-rw-r--r--pkg/tcpip/tcpip.go87
-rw-r--r--pkg/tcpip/tcpip_test.go32
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go44
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go29
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go43
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go26
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go207
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go102
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go8
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go81
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go50
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go1104
36 files changed, 1977 insertions, 1185 deletions
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index c40924852..0d2637ee4 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -24,6 +24,7 @@ go_test(
embed = [":gonet"],
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 308f620e5..cd6ce930a 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -404,7 +404,7 @@ func (c *Conn) Write(b []byte) (int, error) {
}
}
- var n uintptr
+ var n int64
var resCh <-chan struct{}
n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
nbytes += int(n)
@@ -556,32 +556,50 @@ type PacketConn struct {
wq *waiter.Queue
}
-// NewPacketConn creates a new PacketConn.
-func NewPacketConn(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
- // Create UDP endpoint and bind it.
+// DialUDP creates a new PacketConn.
+//
+// If laddr is nil, a local address is automatically chosen.
+//
+// If raddr is nil, the PacketConn is left unconnected.
+func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
if err != nil {
return nil, errors.New(err.String())
}
- if err := ep.Bind(addr); err != nil {
- ep.Close()
- return nil, &net.OpError{
- Op: "bind",
- Net: "udp",
- Addr: fullToUDPAddr(addr),
- Err: errors.New(err.String()),
+ if laddr != nil {
+ if err := ep.Bind(*laddr); err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "bind",
+ Net: "udp",
+ Addr: fullToUDPAddr(*laddr),
+ Err: errors.New(err.String()),
+ }
}
}
- c := &PacketConn{
+ c := PacketConn{
stack: s,
ep: ep,
wq: &wq,
}
c.deadlineTimer.init()
- return c, nil
+
+ if raddr != nil {
+ if err := c.ep.Connect(*raddr); err != nil {
+ c.ep.Close()
+ return nil, &net.OpError{
+ Op: "connect",
+ Net: "udp",
+ Addr: fullToUDPAddr(*raddr),
+ Err: errors.New(err.String()),
+ }
+ }
+ }
+
+ return &c, nil
}
func (c *PacketConn) newOpError(op string, err error) *net.OpError {
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 39efe44c7..672f026b2 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -26,6 +26,7 @@ import (
"golang.org/x/net/nettest"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -69,17 +70,13 @@ func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
s.SetRouteTable([]tcpip.Route{
// IPv4
{
- Destination: tcpip.Address(strings.Repeat("\x00", 4)),
- Mask: tcpip.AddressMask(strings.Repeat("\x00", 4)),
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: NICID,
},
// IPv6
{
- Destination: tcpip.Address(strings.Repeat("\x00", 16)),
- Mask: tcpip.AddressMask(strings.Repeat("\x00", 16)),
- Gateway: "",
+ Destination: header.IPv6EmptySubnet,
NIC: NICID,
},
})
@@ -371,9 +368,9 @@ func TestUDPForwarder(t *testing.T) {
})
s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)
- c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber)
+ c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewPacketConn(port 5):", err)
+ t.Fatal("DialUDP(bind port 5):", err)
}
sent := "abc123"
@@ -452,13 +449,13 @@ func TestPacketConnTransfer(t *testing.T) {
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
- c1, err := NewPacketConn(s, addr1, ipv4.ProtocolNumber)
+ c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewPacketConn(port 4):", err)
+ t.Fatal("DialUDP(bind port 4):", err)
}
- c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber)
+ c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewPacketConn(port 5):", err)
+ t.Fatal("DialUDP(bind port 5):", err)
}
c1.SetDeadline(time.Now().Add(time.Second))
@@ -491,6 +488,50 @@ func TestPacketConnTransfer(t *testing.T) {
}
}
+func TestConnectedPacketConnTransfer(t *testing.T) {
+ s, e := newLoopbackStack()
+ if e != nil {
+ t.Fatalf("newLoopbackStack() = %v", e)
+ }
+
+ ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr := tcpip.FullAddress{NICID, ip, 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+
+ c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 4):", err)
+ }
+ c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 5):", err)
+ }
+
+ c1.SetDeadline(time.Now().Add(time.Second))
+ c2.SetDeadline(time.Now().Add(time.Second))
+
+ sent := "abc123"
+ if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) {
+ t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil)
+ }
+ recv := make([]byte, len(sent))
+ n, err := c1.Read(recv)
+ if err != nil || n != len(recv) {
+ t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil)
+ }
+
+ if recv := string(recv); recv != sent {
+ t.Errorf("got recv = %q, want = %q", recv, sent)
+ }
+
+ if err := c1.Close(); err != nil {
+ t.Error("c1.Close():", err)
+ }
+ if err := c2.Close(); err != nil {
+ t.Error("c2.Close():", err)
+ }
+}
+
func makePipe() (c1, c2 net.Conn, stop func(), err error) {
s, e := newLoopbackStack()
if e != nil {
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 94a3af289..17fc9c68e 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -111,6 +111,15 @@ const (
IPv4FlagDontFragment
)
+// IPv4EmptySubnet is the empty IPv4 subnet.
+var IPv4EmptySubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
// IPVersion returns the version of IP used in the given packet. It returns -1
// if the packet is not large enough to contain the version field.
func IPVersion(b []byte) int {
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 95fe8bfc3..31be42ce0 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -82,6 +82,15 @@ const (
IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
)
+// IPv6EmptySubnet is the empty IPv6 subnet.
+var IPv6EmptySubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
// PayloadLength returns the value of the "payload length" field of the ipv6
// header.
func (b IPv6) PayloadLength() uint16 {
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 11ba31ca4..088eb8a21 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -7,8 +7,9 @@ go_library(
srcs = [
"blockingpoll_amd64.s",
"blockingpoll_arm64.s",
+ "blockingpoll_noyield_unsafe.go",
"blockingpoll_unsafe.go",
- "blockingpoll_stub_unsafe.go",
+ "blockingpoll_yield_unsafe.go",
"errors.go",
"rawfile_unsafe.go",
],
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
index 0bc873a01..b62888b93 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
+++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_stub_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
index 621ab8d29..621ab8d29 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_stub_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
index eeca47d78..84dc0e918 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
@@ -12,49 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64 linux,arm64
-// +build go1.12
-// +build !go1.14
-
-// Check go:linkname function signatures when updating Go version.
+// +build linux,!amd64
package rawfile
import (
"syscall"
- _ "unsafe" // for go:linkname
+ "unsafe"
)
-//go:noescape
-func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno)
-
-// Use go:linkname to call into the runtime. As of Go 1.12 this has to
-// be done from Go code so that we make an ABIInternal call to an
-// ABIInternal function; see https://golang.org/issue/27539.
-
-// We need to call both entersyscallblock and exitsyscall this way so
-// that the runtime's check on the stack pointer lines up.
-
-// Note that calling an unexported function in the runtime package is
-// unsafe and this hack is likely to break in future Go releases.
-
-//go:linkname entersyscallblock runtime.entersyscallblock
-func entersyscallblock()
-
-//go:linkname exitsyscall runtime.exitsyscall
-func exitsyscall()
-
-// These forwarding functions must be nosplit because 1) we must
-// disallow preemption between entersyscallblock and exitsyscall, and
-// 2) we have an untyped assembly frame on the stack which can not be
-// grown or moved.
-
-//go:nosplit
-func callEntersyscallblock() {
- entersyscallblock()
-}
+// BlockingPoll is just a stub function that forwards to the ppoll() system call
+// on non-amd64 platforms.
+func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) {
+ n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)),
+ uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0)
-//go:nosplit
-func callExitsyscall() {
- exitsyscall()
+ return int(n), e
}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
new file mode 100644
index 000000000..dda3b10a6
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -0,0 +1,66 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,amd64 linux,arm64
+// +build go1.12
+// +build !go1.14
+
+// Check go:linkname function signatures when updating Go version.
+
+package rawfile
+
+import (
+ "syscall"
+ _ "unsafe" // for go:linkname
+)
+
+// BlockingPoll on amd64/arm64 makes the ppoll() syscall while calling the
+// version of entersyscall that relinquishes the P so that other Gs can
+// run. This is meant to be called in cases when the syscall is expected to
+// block. On non amd64/arm64 platforms it just forwards to the ppoll() system
+// call.
+//
+//go:noescape
+func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno)
+
+// Use go:linkname to call into the runtime. As of Go 1.12 this has to
+// be done from Go code so that we make an ABIInternal call to an
+// ABIInternal function; see https://golang.org/issue/27539.
+
+// We need to call both entersyscallblock and exitsyscall this way so
+// that the runtime's check on the stack pointer lines up.
+
+// Note that calling an unexported function in the runtime package is
+// unsafe and this hack is likely to break in future Go releases.
+
+//go:linkname entersyscallblock runtime.entersyscallblock
+func entersyscallblock()
+
+//go:linkname exitsyscall runtime.exitsyscall
+func exitsyscall()
+
+// These forwarding functions must be nosplit because 1) we must
+// disallow preemption between entersyscallblock and exitsyscall, and
+// 2) we have an untyped assembly frame on the stack which can not be
+// grown or moved.
+
+//go:nosplit
+func callEntersyscallblock() {
+ entersyscallblock()
+}
+
+//go:nosplit
+func callExitsyscall() {
+ exitsyscall()
+}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index fc584c6a4..36c8c46fc 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -360,10 +360,9 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize {
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
+ details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
+ size -= header.UDPMinimumSize
}
- size -= header.UDPMinimumSize
-
- details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
case header.TCPProtocolNumber:
transName = "tcp"
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index e477046db..4c4b54469 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -66,9 +66,7 @@ func newTestContext(t *testing.T) *testContext {
}
s.SetRouteTable([]tcpip.Route{{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
}})
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 55e9eec99..6bbfcd97f 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -173,8 +173,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
- Destination: ipv4SubnetAddr,
- Mask: ipv4SubnetMask,
+ Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
NIC: 1,
}})
@@ -187,8 +186,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
- Destination: ipv6SubnetAddr,
- Mask: ipv6SubnetMask,
+ Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
NIC: 1,
}})
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 3207a3d46..1b5a55bea 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -52,9 +52,7 @@ func TestExcludeBroadcast(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
}})
@@ -247,14 +245,22 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
_, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
linkEPId := stack.RegisterLinkEndpoint(linkEP)
s.CreateNIC(1, linkEPId)
- s.AddAddress(1, ipv4.ProtocolNumber, "\x10\x00\x00\x01")
- s.SetRouteTable([]tcpip.Route{{
- Destination: "\x10\x00\x00\x02",
- Mask: "\xff\xff\xff\xff",
- Gateway: "",
- NIC: 1,
- }})
- r, err := s.FindRoute(0, "\x10\x00\x00\x01", "\x10\x00\x00\x02", ipv4.ProtocolNumber, false /* multicastLoop */)
+ const (
+ src = "\x10\x00\x00\x01"
+ dst = "\x10\x00\x00\x02"
+ )
+ s.AddAddress(1, ipv4.ProtocolNumber, src)
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("s.FindRoute got %v, want %v", err, nil)
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 726362c87..d0dc72506 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -91,13 +91,18 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
}
}
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: lladdr1,
- Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)),
- NIC: 1,
- }},
- )
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
netProto := s.NetworkProtocolInstance(ProtocolNumber)
if netProto == nil {
@@ -237,17 +242,23 @@ func newTestContext(t *testing.T) *testContext {
t.Fatalf("AddAddress sn lladdr1: %v", err)
}
+ subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
c.s0.SetRouteTable(
[]tcpip.Route{{
- Destination: lladdr1,
- Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)),
+ Destination: subnet0,
NIC: 1,
}},
)
+ subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
+ if err != nil {
+ t.Fatal(err)
+ }
c.s1.SetRouteTable(
[]tcpip.Route{{
- Destination: lladdr0,
- Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)),
+ Destination: subnet1,
NIC: 1,
}},
)
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
index 996939581..a57752a7c 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -8,6 +8,7 @@ go_binary(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/link/fdbased",
"//pkg/tcpip/link/rawfile",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 3ac381631..e2021cd15 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -52,6 +52,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
"gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -152,9 +153,7 @@ func main() {
// Add default route.
s.SetRouteTable([]tcpip.Route{
{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
},
})
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index da425394a..1716be285 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -149,12 +149,15 @@ func main() {
log.Fatal(err)
}
+ subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
+ if err != nil {
+ log.Fatal(err)
+ }
+
// Add default route.
s.SetRouteTable([]tcpip.Route{
{
- Destination: tcpip.Address(strings.Repeat("\x00", len(addr))),
- Mask: tcpip.AddressMask(strings.Repeat("\x00", len(addr))),
- Gateway: "",
+ Destination: subnet,
NIC: 1,
},
})
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index dc28dc970..04b63d783 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -139,7 +139,7 @@ func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Add
if list, ok := n.primary[protocol]; ok {
for e := list.Front(); e != nil; e = e.Next() {
ref := e.(*referencedNetworkEndpoint)
- if ref.holdsInsertRef && ref.tryIncRef() {
+ if ref.kind == permanent && ref.tryIncRef() {
r = ref
break
}
@@ -178,7 +178,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
case header.IPv4Broadcast, header.IPv4Any:
continue
}
- if r.tryIncRef() {
+ if r.isValidForOutgoing() && r.tryIncRef() {
return r
}
}
@@ -186,46 +186,124 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
return nil
}
+func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
+ return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous)
+}
+
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing)
+}
+
+// getRefEpOrCreateTemp returns the referenced network endpoint for the given
+// protocol and address. If none exists a temporary one may be created if
+// we are in promiscuous mode or spoofing.
+func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
- ref := n.endpoints[id]
- if ref != nil && !ref.tryIncRef() {
- ref = nil
+
+ if ref, ok := n.endpoints[id]; ok {
+ // An endpoint with this id exists, check if it can be used and return it.
+ switch ref.kind {
+ case permanentExpired:
+ if !spoofingOrPromiscuous {
+ n.mu.RUnlock()
+ return nil
+ }
+ fallthrough
+ case temporary, permanent:
+ if ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
+ }
+ }
}
- spoofing := n.spoofing
+
+ // A usable reference was not found, create a temporary one if requested by
+ // the caller or if the address is found in the NIC's subnets.
+ createTempEP := spoofingOrPromiscuous
+ if !createTempEP {
+ for _, sn := range n.subnets {
+ if sn.Contains(address) {
+ createTempEP = true
+ break
+ }
+ }
+ }
+
n.mu.RUnlock()
- if ref != nil || !spoofing {
- return ref
+ if !createTempEP {
+ return nil
}
// Try again with the lock in exclusive mode. If we still can't get the
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- ref = n.endpoints[id]
- if ref == nil || !ref.tryIncRef() {
- if netProto, ok := n.stack.networkProtocols[protocol]; ok {
- ref, _ = n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, peb, true)
- if ref != nil {
- ref.holdsInsertRef = false
- }
+ if ref, ok := n.endpoints[id]; ok {
+ // No need to check the type as we are ok with expired endpoints at this
+ // point.
+ if ref.tryIncRef() {
+ n.mu.Unlock()
+ return ref
}
+ // tryIncRef failing means the endpoint is scheduled to be removed once the
+ // lock is released. Remove it here so we can create a new (temporary) one.
+ // The removal logic waiting for the lock handles this case.
+ n.removeEndpointLocked(ref)
}
+
+ // Add a new temporary endpoint.
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ n.mu.Unlock()
+ return nil
+ }
+ ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: netProto.DefaultPrefixLen(),
+ },
+ }, peb, temporary)
+
n.mu.Unlock()
return ref
}
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) {
+ id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
+ if ref, ok := n.endpoints[id]; ok {
+ switch ref.kind {
+ case permanent:
+ // The NIC already have a permanent endpoint with that address.
+ return nil, tcpip.ErrDuplicateAddress
+ case permanentExpired, temporary:
+ // Promote the endpoint to become permanent.
+ if ref.tryIncRef() {
+ ref.kind = permanent
+ return ref, nil
+ }
+ // tryIncRef failing means the endpoint is scheduled to be removed once
+ // the lock is released. Remove it here so we can create a new
+ // (permanent) one. The removal logic waiting for the lock handles this
+ // case.
+ n.removeEndpointLocked(ref)
+ }
+ }
+ return n.addAddressLocked(protocolAddress, peb, permanent)
+}
+
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) {
+ // Sanity check.
+ id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
+ if _, ok := n.endpoints[id]; ok {
+ // Endpoint already exists.
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol]
if !ok {
return nil, tcpip.ErrUnknownProtocol
@@ -236,22 +314,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
if err != nil {
return nil, err
}
-
- id := *ep.ID()
- if ref, ok := n.endpoints[id]; ok {
- if !replace {
- return nil, tcpip.ErrDuplicateAddress
- }
-
- n.removeEndpointLocked(ref)
- }
-
ref := &referencedNetworkEndpoint{
- refs: 1,
- ep: ep,
- nic: n,
- protocol: protocolAddress.Protocol,
- holdsInsertRef: true,
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocolAddress.Protocol,
+ kind: kind,
}
// Set up cache if link address resolution exists for this protocol.
@@ -284,7 +352,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
- _, err := n.addAddressLocked(protocolAddress, peb, false)
+ _, err := n.addPermanentAddressLocked(protocolAddress, peb)
n.mu.Unlock()
return err
@@ -296,6 +364,12 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress {
defer n.mu.RUnlock()
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ref := range n.endpoints {
+ // Don't include expired or tempory endpoints to avoid confusion and
+ // prevent the caller from using those.
+ switch ref.kind {
+ case permanentExpired, temporary:
+ continue
+ }
addrs = append(addrs, tcpip.ProtocolAddress{
Protocol: ref.protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -361,13 +435,16 @@ func (n *NIC) Subnets() []tcpip.Subnet {
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
id := *r.ep.ID()
- // Nothing to do if the reference has already been replaced with a
- // different one.
+ // Nothing to do if the reference has already been replaced with a different
+ // one. This happens in the case where 1) this endpoint's ref count hit zero
+ // and was waiting (on the lock) to be removed and 2) the same address was
+ // re-added in the meantime by removing this endpoint from the list and
+ // adding a new one.
if n.endpoints[id] != r {
return
}
- if r.holdsInsertRef {
+ if r.kind == permanent {
panic("Reference count dropped to zero before being removed")
}
@@ -386,14 +463,13 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
n.mu.Unlock()
}
-func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
+func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
r := n.endpoints[NetworkEndpointID{addr}]
- if r == nil || !r.holdsInsertRef {
+ if r == nil || r.kind != permanent {
return tcpip.ErrBadLocalAddress
}
- r.holdsInsertRef = false
-
+ r.kind = permanentExpired
r.decRefLocked()
return nil
@@ -403,7 +479,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- return n.removeAddressLocked(addr)
+ return n.removePermanentAddressLocked(addr)
}
// joinGroup adds a new endpoint for the given multicast address, if none
@@ -419,13 +495,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
if !ok {
return tcpip.ErrUnknownProtocol
}
- if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
+ if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: addr,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, NeverPrimaryEndpoint, false); err != nil {
+ }, NeverPrimaryEndpoint); err != nil {
return err
}
}
@@ -447,7 +523,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadLocalAddress
case 1:
// This is the last one, clean up.
- if err := n.removeAddressLocked(addr); err != nil {
+ if err := n.removePermanentAddressLocked(addr); err != nil {
return err
}
}
@@ -489,7 +565,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
// n.endpoints is mutex protected so acquire lock.
n.mu.RLock()
for _, ref := range n.endpoints {
- if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
+ if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remote
ref.ep.HandlePacket(&r, vv)
@@ -527,8 +603,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n := r.ref.nic
n.mu.RLock()
ref, ok := n.endpoints[NetworkEndpointID{dst}]
+ ok = ok && ref.isValidForOutgoing() && ref.tryIncRef()
n.mu.RUnlock()
- if ok && ref.tryIncRef() {
+ if ok {
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
ref.ep.HandlePacket(&r, vv)
@@ -553,57 +630,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n.stack.stats.IP.InvalidAddressesReceived.Increment()
}
-func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
- id := NetworkEndpointID{dst}
-
- n.mu.RLock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
- }
-
- promiscuous := n.promiscuous
- // Check if the packet is for a subnet this NIC cares about.
- if !promiscuous {
- for _, sn := range n.subnets {
- if sn.Contains(dst) {
- promiscuous = true
- break
- }
- }
- }
- n.mu.RUnlock()
- if promiscuous {
- // Try again with the lock in exclusive mode. If we still can't
- // get the endpoint, create a new "temporary" one. It will only
- // exist while there's a route through it.
- n.mu.Lock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.Unlock()
- return ref
- }
- netProto, ok := n.stack.networkProtocols[protocol]
- if !ok {
- n.mu.Unlock()
- return nil
- }
- ref, err := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: dst,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, CanBePrimaryEndpoint, true)
- n.mu.Unlock()
- if err == nil {
- ref.holdsInsertRef = false
- return ref
- }
- }
-
- return nil
-}
-
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
@@ -691,9 +717,33 @@ func (n *NIC) ID() tcpip.NICID {
return n.id
}
+type networkEndpointKind int
+
+const (
+ // A permanent endpoint is created by adding a permanent address (vs. a
+ // temporary one) to the NIC. Its reference count is biased by 1 to avoid
+ // removal when no route holds a reference to it. It is removed by explicitly
+ // removing the permanent address from the NIC.
+ permanent networkEndpointKind = iota
+
+ // An expired permanent endoint is a permanent endoint that had its address
+ // removed from the NIC, and it is waiting to be removed once no more routes
+ // hold a reference to it. This is achieved by decreasing its reference count
+ // by 1. If its address is re-added before the endpoint is removed, its type
+ // changes back to permanent and its reference count increases by 1 again.
+ permanentExpired
+
+ // A temporary endpoint is created for spoofing outgoing packets, or when in
+ // promiscuous mode and accepting incoming packets that don't match any
+ // permanent endpoint. Its reference count is not biased by 1 and the
+ // endpoint is removed immediately when no more route holds a reference to
+ // it. A temporary endpoint can be promoted to permanent if its address
+ // is added permanently.
+ temporary
+)
+
type referencedNetworkEndpoint struct {
ilist.Entry
- refs int32
ep NetworkEndpoint
nic *NIC
protocol tcpip.NetworkProtocolNumber
@@ -702,11 +752,25 @@ type referencedNetworkEndpoint struct {
// protocol. Set to nil otherwise.
linkCache LinkAddressCache
- // holdsInsertRef is protected by the NIC's mutex. It indicates whether
- // the reference count is biased by 1 due to the insertion of the
- // endpoint. It is reset to false when RemoveAddress is called on the
- // NIC.
- holdsInsertRef bool
+ // refs is counting references held for this endpoint. When refs hits zero it
+ // triggers the automatic removal of the endpoint from the NIC.
+ refs int32
+
+ kind networkEndpointKind
+}
+
+// isValidForOutgoing returns true if the endpoint can be used to send out a
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), or the NIC to be in spoofing mode.
+func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
+ return r.kind != permanentExpired || r.nic.spoofing
+}
+
+// isValidForIncoming returns true if the endpoint can accept an incoming
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), or the NIC to be in promiscuous mode.
+func (r *referencedNetworkEndpoint) isValidForIncoming() bool {
+ return r.kind != permanentExpired || r.nic.promiscuous
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 391ab4344..e52cdd674 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -148,11 +148,15 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before the this route can be written to.
func (r *Route) IsResolutionRequired() bool {
- return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+ return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
@@ -166,6 +170,10 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index d45e547ee..d69162ba1 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -895,7 +895,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
} else {
for _, route := range s.routeTable {
- if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Match(remoteAddr)) {
+ if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) {
continue
}
if nic, ok := s.nics[route.NIC]; ok {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 1ab9c575b..4debd1eec 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -181,6 +181,10 @@ func (f *fakeNetworkProtocol) DefaultPrefixLen() int {
return fakeDefaultPrefixLen
}
+func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
+ return f.packetCount[int(intfAddr)%len(f.packetCount)]
+}
+
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
@@ -289,16 +293,75 @@ func TestNetworkReceive(t *testing.T) {
}
}
-func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) {
+func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error {
r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatal("FindRoute failed:", err)
+ return err
}
defer r.Release()
+ return send(r, payload)
+}
+func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil {
- t.Error("WritePacket failed:", err)
+ return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123)
+}
+
+func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ linkEP.Drain()
+ if err := sendTo(s, addr, payload); err != nil {
+ t.Error("sendTo failed:", err)
+ }
+ if got, want := linkEP.Drain(), 1; got != want {
+ t.Errorf("sendTo packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ linkEP.Drain()
+ if err := send(r, payload); err != nil {
+ t.Error("send failed:", err)
+ }
+ if got, want := linkEP.Drain(), 1; got != want {
+ t.Errorf("send packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := send(r, payload); gotErr != wantErr {
+ t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
+ t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte) + 1
+ testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want)
+}
+
+func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we do NOT expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte)
+ testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want)
+}
+
+func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) {
+ t.Helper()
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ if got := fakeNet.PacketCount(localAddrByte); got != want {
+ t.Errorf("receive packet count: got = %d, want %d", got, want)
}
}
@@ -312,17 +375,20 @@ func TestNetworkSend(t *testing.T) {
t.Fatal("NewNIC failed:", err)
}
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress failed:", err)
}
// Make sure that the link-layer endpoint received the outbound packet.
- sendTo(t, s, "\x03", nil)
- if c := linkEP.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x03", linkEP, nil)
}
func TestNetworkSendMultiRoute(t *testing.T) {
@@ -360,24 +426,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
// Set a route table that sends all packets with odd destination
// addresses through the first NIC, and all even destination address
// through the second one.
- s.SetRouteTable([]tcpip.Route{
- {"\x01", "\x01", "\x00", 1},
- {"\x00", "\x01", "\x00", 2},
- })
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet0, Gateway: "\x00", NIC: 2},
+ })
+ }
// Send a packet to an odd destination.
- sendTo(t, s, "\x05", nil)
-
- if c := linkEP1.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x05", linkEP1, nil)
// Send a packet to an even destination.
- sendTo(t, s, "\x06", nil)
-
- if c := linkEP2.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x06", linkEP2, nil)
}
func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
@@ -439,10 +507,20 @@ func TestRoutes(t *testing.T) {
// Set a route table that sends all packets with odd destination
// addresses through the first NIC, and all even destination address
// through the second one.
- s.SetRouteTable([]tcpip.Route{
- {"\x01", "\x01", "\x00", 1},
- {"\x00", "\x01", "\x00", 2},
- })
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet0, Gateway: "\x00", NIC: 2},
+ })
+ }
// Test routes to odd address.
testRoute(t, s, 0, "", "\x05", "\x01")
@@ -472,6 +550,10 @@ func TestRoutes(t *testing.T) {
}
func TestAddressRemoval(t *testing.T) {
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
+
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
@@ -479,99 +561,285 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ // Send and receive packets, and verify they are received.
+ buf[0] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
- // Remove the address, then check that packet doesn't get delivered
- // anymore.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
}
-func TestDelayedRemovalDueToRoute(t *testing.T) {
+func TestAddressRemovalWithRouteHeld(t *testing.T) {
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
+
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatal("CreateNIC failed:", err)
}
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
-
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
buf := buffer.NewView(30)
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- // Get a route, check that packet is still deliverable.
- r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 2 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2)
- }
+ // Send and receive packets, and verify they are received.
+ buf[0] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSend(t, r, linkEP, nil)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
- // Remove the address, then check that packet is still deliverable
- // because the route is keeping the address alive.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
+}
- // Release the route, then check that packet is not deliverable anymore.
- r.Release()
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
+func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) {
+ t.Helper()
+ info, ok := s.NICInfo()[nicid]
+ if !ok {
+ t.Fatalf("NICInfo() failed to find nicid=%d", nicid)
+ }
+ if len(addr) == 0 {
+ // No address given, verify that there is no address assigned to the NIC.
+ for _, a := range info.ProtocolAddresses {
+ if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) {
+ t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{}))
+ }
+ }
+ return
+ }
+ // Address given, verify the address is assigned to the NIC and no other
+ // address is.
+ found := false
+ for _, a := range info.ProtocolAddresses {
+ if a.Protocol == fakeNetNumber {
+ if a.AddressWithPrefix.Address == addr {
+ found = true
+ } else {
+ t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr)
+ }
+ }
+ }
+ if !found {
+ t.Errorf("verify address: couldn't find %s on the NIC", addr)
+ }
+}
+
+func TestEndpointExpiration(t *testing.T) {
+ const (
+ localAddrByte byte = 0x01
+ remoteAddr tcpip.Address = "\x03"
+ noAddr tcpip.Address = ""
+ nicid tcpip.NICID = 1
+ )
+ localAddr := tcpip.Address([]byte{localAddrByte})
+
+ for _, promiscuous := range []bool{true, false} {
+ for _, spoofing := range []bool{true, false} {
+ t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+
+ id, linkEP := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, id); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+ buf := buffer.NewView(30)
+ buf[0] = localAddrByte
+
+ if promiscuous {
+ if err := s.SetPromiscuousMode(nicid, true); err != nil {
+ t.Fatal("SetPromiscuousMode failed:", err)
+ }
+ }
+
+ if spoofing {
+ if err := s.SetSpoofing(nicid, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ }
+
+ // 1. No Address yet, send should only work for spoofing, receive for
+ // promiscuous mode.
+ //-----------------------
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, linkEP, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ }
+
+ // 2. Add Address, everything should work.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+
+ // 3. Remove the address, send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, linkEP, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ }
+
+ // 4. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+
+ // 5. Take a reference to the endpoint by getting a route. Verify that
+ // we can still send/receive, including sending using the route.
+ //-----------------------
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+ testSend(t, r, linkEP, nil)
+
+ // 6. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ }
+ if spoofing {
+ testSend(t, r, linkEP, nil)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+ } else {
+ testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ }
+
+ // 7. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+ testSend(t, r, linkEP, nil)
+
+ // 8. Remove the route, sendTo/recv should still work.
+ //-----------------------
+ r.Release()
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+
+ // 9. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, linkEP, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ }
+ })
+ }
}
}
@@ -583,9 +851,13 @@ func TestPromiscuousMode(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
@@ -593,22 +865,15 @@ func TestPromiscuousMode(t *testing.T) {
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
// Set promiscuous mode, then check that packet is delivered.
if err := s.SetPromiscuousMode(1, true); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
// Check that we can't get a route as there is no local address.
_, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
@@ -621,54 +886,120 @@ func TestPromiscuousMode(t *testing.T) {
if err := s.SetPromiscuousMode(1, false); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+func TestSpoofingWithAddress(t *testing.T) {
+ localAddr := tcpip.Address("\x01")
+ nonExistentLocalAddr := tcpip.Address("\x02")
+ dstAddr := tcpip.Address("\x03")
+
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+
+ id, linkEP := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
+
+ // With address spoofing disabled, FindRoute does not permit an address
+ // that was not added to the NIC to be used as the source.
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err == nil {
+ t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
+ }
+
+ // With address spoofing enabled, FindRoute permits any address to be used
+ // as the source.
+ if err := s.SetSpoofing(1, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+ }
+ // Sending a packet works.
+ testSendTo(t, s, dstAddr, linkEP, nil)
+ testSend(t, r, linkEP, nil)
+
+ // FindRoute should also work with a local address that exists on the NIC.
+ r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != localAddr {
+ t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+ }
+ // Sending a packet using the route works.
+ testSend(t, r, linkEP, nil)
}
-func TestAddressSpoofing(t *testing.T) {
- srcAddr := tcpip.Address("\x01")
+func TestSpoofingNoAddress(t *testing.T) {
+ nonExistentLocalAddr := tcpip.Address("\x01")
dstAddr := tcpip.Address("\x02")
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
+ id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
-
// With address spoofing disabled, FindRoute does not permit an address
// that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err == nil {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
+ // Sending a packet fails.
+ testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute)
// With address spoofing enabled, FindRoute permits any address to be used
// as the source.
if err := s.SetSpoofing(1, true); err != nil {
t.Fatal("SetSpoofing failed:", err)
}
- r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- if r.LocalAddress != srcAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr)
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
}
+ // Sending a packet works.
+ // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, linkEP, nil)
}
func TestBroadcastNeedsNoRoute(t *testing.T) {
@@ -806,16 +1137,20 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- buf[0] = 1
- fakeNet.packetCount[1] = 0
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -824,9 +1159,52 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) {
t.Fatal("AddSubnet failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+}
+
+// Set the subnet, then check that CheckLocalAddress returns the correct NIC.
+func TestCheckLocalAddressForSubnet(t *testing.T) {
+ const nicID tcpip.NICID = 1
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+
+ id, _ := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, id); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}})
+ }
+
+ subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0"))
+
+ if err != nil {
+ t.Fatal("NewSubnet failed:", err)
+ }
+ if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddSubnet failed:", err)
+ }
+
+ // Loop over all subnet addresses and check them.
+ numOfAddresses := 1 << uint(8-subnet.Prefix())
+ if numOfAddresses < 1 || numOfAddresses > 255 {
+ t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
+ }
+ addr := []byte(subnet.ID())
+ for i := 0; i < numOfAddresses; i++ {
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID)
+ }
+ addr[0]++
+ }
+
+ // Trying the next address should fail since it is outside the subnet range.
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0)
}
}
@@ -839,16 +1217,20 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- buf[0] = 1
- fakeNet.packetCount[1] = 0
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -856,10 +1238,7 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
t.Fatal("AddSubnet failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
}
func TestNetworkOptions(t *testing.T) {
@@ -1213,15 +1592,19 @@ func TestNICStats(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id1, linkEP1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id1); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ t.Fatal("CreateNIC failed: ", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress failed:", err)
}
// Route all packets for address \x01 to NIC 1.
- s.SetRouteTable([]tcpip.Route{
- {"\x01", "\xff", "\x00", 1},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x01", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
// Send a packet to address 1.
buf := buffer.NewView(30)
@@ -1236,7 +1619,9 @@ func TestNICStats(t *testing.T) {
payload := buffer.NewView(10)
// Write a packet out via the address for NIC 1
- sendTo(t, s, "\x01", payload)
+ if err := sendTo(s, "\x01", payload); err != nil {
+ t.Fatal("sendTo failed: ", err)
+ }
want := uint64(linkEP1.Drain())
if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want {
t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want)
@@ -1270,9 +1655,13 @@ func TestNICForwarding(t *testing.T) {
}
// Route all packets to address 3 to NIC 2.
- s.SetRouteTable([]tcpip.Route{
- {"\x03", "\xff", "\x00", 2},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x03", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}})
+ }
// Send a packet to address 3.
buf := buffer.NewView(30)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index eee3144cd..5335897f5 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -65,7 +65,7 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr
return buffer.View{}, tcpip.ControlMessages{}, nil
}
-func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
return 0, nil, tcpip.ErrNoRoute
}
@@ -79,10 +79,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions)
return 0, nil, err
}
- return uintptr(len(v)), nil, nil
+ return int64(len(v)), nil, nil
}
-func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -105,6 +105,11 @@ func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*fakeTransportEndpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
@@ -279,7 +284,13 @@ func TestTransportReceive(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
@@ -335,7 +346,13 @@ func TestTransportControlReceive(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
@@ -401,7 +418,13 @@ func TestTransportSend(t *testing.T) {
t.Fatalf("AddAddress failed: %v", err)
}
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
// Create endpoint and bind it.
wq := waiter.Queue{}
@@ -492,10 +515,20 @@ func TestTransportForwarding(t *testing.T) {
// Route all packets to address 3 to NIC 2 and all packets to address
// 1 to NIC 1.
- s.SetRouteTable([]tcpip.Route{
- {"\x03", "\xff", "\x00", 2},
- {"\x01", "\xff", "\x00", 1},
- })
+ {
+ subnet0, err := tcpip.NewSubnet("\x03", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet0, Gateway: "\x00", NIC: 2},
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ })
+ }
wq := waiter.Queue{}
ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 119712d2f..8f9b86cce 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -31,6 +31,7 @@ package tcpip
import (
"errors"
"fmt"
+ "math/bits"
"reflect"
"strconv"
"strings"
@@ -145,8 +146,17 @@ type Address string
type AddressMask string
// String implements Stringer.
-func (a AddressMask) String() string {
- return Address(a).String()
+func (m AddressMask) String() string {
+ return Address(m).String()
+}
+
+// Prefix returns the number of bits before the first host bit.
+func (m AddressMask) Prefix() int {
+ p := 0
+ for _, b := range []byte(m) {
+ p += bits.LeadingZeros8(^b)
+ }
+ return p
}
// Subnet is a subnet defined by its address and mask.
@@ -168,6 +178,11 @@ func NewSubnet(a Address, m AddressMask) (Subnet, error) {
return Subnet{a, m}, nil
}
+// String implements Stringer.
+func (s Subnet) String() string {
+ return fmt.Sprintf("%s/%d", s.ID(), s.Prefix())
+}
+
// Contains returns true iff the address is of the same length and matches the
// subnet address and mask.
func (s *Subnet) Contains(a Address) bool {
@@ -190,28 +205,13 @@ func (s *Subnet) ID() Address {
// Bits returns the number of ones (network bits) and zeros (host bits) in the
// subnet mask.
func (s *Subnet) Bits() (ones int, zeros int) {
- for _, b := range []byte(s.mask) {
- for i := uint(0); i < 8; i++ {
- if b&(1<<i) == 0 {
- zeros++
- } else {
- ones++
- }
- }
- }
- return
+ ones = s.mask.Prefix()
+ return ones, len(s.mask)*8 - ones
}
// Prefix returns the number of bits before the first host bit.
func (s *Subnet) Prefix() int {
- for i, b := range []byte(s.mask) {
- for j := 7; j >= 0; j-- {
- if b&(1<<uint(j)) == 0 {
- return i*8 + 7 - j
- }
- }
- }
- return len(s.mask) * 8
+ return s.mask.Prefix()
}
// Mask returns the subnet mask.
@@ -329,12 +329,12 @@ type Endpoint interface {
// ErrNoLinkAddress and a notification channel is returned for the caller to
// block. Channel is closed once address resolution is complete (success or
// not). The channel is only non-nil in this case.
- Write(Payload, WriteOptions) (uintptr, <-chan struct{}, *Error)
+ Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error)
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
- Peek([][]byte) (uintptr, ControlMessages, *Error)
+ Peek([][]byte) (int64, ControlMessages, *Error)
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
@@ -353,6 +353,9 @@ type Endpoint interface {
// ErrAddressFamilyNotSupported must be returned.
Connect(address FullAddress) *Error
+ // Disconnect disconnects the endpoint from its peer.
+ Disconnect() *Error
+
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
Shutdown(flags ShutdownFlags) *Error
@@ -567,13 +570,8 @@ type BroadcastOption int
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
type Route struct {
- // Destination is the address that must be matched against the masked
- // target address to check if this row is viable.
- Destination Address
-
- // Mask specifies which bits of the Destination and the target address
- // must match for this row to be viable.
- Mask AddressMask
+ // Destination must contain the target address for this row to be viable.
+ Destination Subnet
// Gateway is the gateway to be used if this row is viable.
Gateway Address
@@ -582,25 +580,15 @@ type Route struct {
NIC NICID
}
-// Match determines if r is viable for the given destination address.
-func (r *Route) Match(addr Address) bool {
- if len(addr) != len(r.Destination) {
- return false
- }
-
- // Using header.Ipv4Broadcast would introduce an import cycle, so
- // we'll use a literal instead.
- if addr == "\xff\xff\xff\xff" {
- return true
- }
-
- for i := 0; i < len(r.Destination); i++ {
- if (addr[i] & r.Mask[i]) != r.Destination[i] {
- return false
- }
+// String implements the fmt.Stringer interface.
+func (r Route) String() string {
+ var out strings.Builder
+ fmt.Fprintf(&out, "%s", r.Destination)
+ if len(r.Gateway) > 0 {
+ fmt.Fprintf(&out, " via %s", r.Gateway)
}
-
- return true
+ fmt.Fprintf(&out, " nic %d", r.NIC)
+ return out.String()
}
// LinkEndpointID represents a data link layer endpoint.
@@ -1072,6 +1060,11 @@ type AddressWithPrefix struct {
PrefixLen int
}
+// String implements the fmt.Stringer interface.
+func (a AddressWithPrefix) String() string {
+ return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen)
+}
+
// ProtocolAddress is an address and the network protocol it is associated
// with.
type ProtocolAddress struct {
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index ebb1c1b56..fb3a0a5ee 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -60,12 +60,12 @@ func TestSubnetBits(t *testing.T) {
}{
{"\x00", 0, 8},
{"\x00\x00", 0, 16},
- {"\x36", 4, 4},
- {"\x5c", 4, 4},
- {"\x5c\x5c", 8, 8},
- {"\x5c\x36", 8, 8},
- {"\x36\x5c", 8, 8},
- {"\x36\x36", 8, 8},
+ {"\x36", 0, 8},
+ {"\x5c", 0, 8},
+ {"\x5c\x5c", 0, 16},
+ {"\x5c\x36", 0, 16},
+ {"\x36\x5c", 0, 16},
+ {"\x36\x36", 0, 16},
{"\xff", 8, 0},
{"\xff\xff", 16, 0},
}
@@ -122,26 +122,6 @@ func TestSubnetCreation(t *testing.T) {
}
}
-func TestRouteMatch(t *testing.T) {
- tests := []struct {
- d Address
- m AddressMask
- a Address
- want bool
- }{
- {"\xc2\x80", "\xff\xf0", "\xc2\x80", true},
- {"\xc2\x80", "\xff\xf0", "\xc2\x00", false},
- {"\xc2\x00", "\xff\xf0", "\xc2\x00", true},
- {"\xc2\x00", "\xff\xf0", "\xc2\x80", false},
- }
- for _, tt := range tests {
- r := Route{Destination: tt.d, Mask: tt.m}
- if got := r.Match(tt.a); got != tt.want {
- t.Errorf("Route(%v).Match(%v) = %v, want %v", r, tt.a, got, tt.want)
- }
- }
-}
-
func TestAddressString(t *testing.T) {
for _, want := range []string{
// Taken from stdlib.
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 9a4306011..451d3880e 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -136,34 +136,6 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) {
return e.stack.IPTables(), nil
}
-// Resume implements tcpip.ResumableEndpoint.Resume.
-func (e *endpoint) Resume(s *stack.Stack) {
- e.stack = s
-
- if e.state != stateBound && e.state != stateConnected {
- return
- }
-
- var err *tcpip.Error
- if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
- if err != nil {
- panic(*err)
- }
-
- e.id.LocalAddress = e.route.LocalAddress
- } else if len(e.id.LocalAddress) != 0 { // stateBound
- if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 {
- panic(tcpip.ErrBadLocalAddress)
- }
- }
-
- e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
- if err != nil {
- panic(*err)
- }
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -233,7 +205,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -335,11 +307,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
return 0, nil, err
}
- return uintptr(len(v)), nil, nil
+ return int64(len(v)), nil, nil
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -456,16 +428,16 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if addr.Addr == "" {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
nicid := addr.NIC
localPort := uint16(0)
switch e.state {
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index 43551d642..c587b96b6 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -15,6 +15,7 @@
package icmp
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -64,3 +65,31 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
func (e *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
+
+ if e.state != stateBound && e.state != stateConnected {
+ return
+ }
+
+ var err *tcpip.Error
+ if e.state == stateConnected {
+ e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
+ if err != nil {
+ panic(err)
+ }
+
+ e.id.LocalAddress = e.route.LocalAddress
+ } else if len(e.id.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
+ if err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index eab3dcbd2..13e17e2a6 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -174,31 +174,6 @@ func (ep *endpoint) IPTables() (iptables.IPTables, error) {
return ep.stack.IPTables(), nil
}
-// Resume implements tcpip.ResumableEndpoint.Resume.
-func (ep *endpoint) Resume(s *stack.Stack) {
- ep.stack = s
-
- // If the endpoint is connected, re-connect.
- if ep.connected {
- var err *tcpip.Error
- ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
- if err != nil {
- panic(*err)
- }
- }
-
- // If the endpoint is bound, re-bind.
- if ep.bound {
- if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
- panic(tcpip.ErrBadLocalAddress)
- }
- }
-
- if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
- panic(*err)
- }
-}
-
// Read implements tcpip.Endpoint.Read.
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
if !ep.associated {
@@ -232,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
}
// Write implements tcpip.Endpoint.Write.
-func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -336,7 +311,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
// finishWrite writes the payload to a route. It resolves the route if
// necessary. It's really just a helper to make defer unnecessary in Write.
-func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) {
// We may need to resolve the route (match a link layer address to the
// network address). If that requires blocking (e.g. to use ARP),
// return a channel on which the caller can wait.
@@ -366,24 +341,24 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintpt
return 0, nil, tcpip.ErrUnknownProtocol
}
- return uintptr(len(payloadBytes)), nil, nil
+ return int64(len(payloadBytes)), nil, nil
}
// Peek implements tcpip.Endpoint.Peek.
-func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect implements tcpip.Endpoint.Connect.
func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if addr.Addr == "" {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
if ep.closed {
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 44abddb2b..168953dec 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -15,6 +15,7 @@
package raw
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -64,3 +65,28 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) {
func (ep *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(ep)
}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (ep *endpoint) Resume(s *stack.Stack) {
+ ep.stack = s
+
+ // If the endpoint is connected, re-connect.
+ if ep.connected {
+ var err *tcpip.Error
+ ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
+ if err != nil {
+ panic(err)
+ }
+ }
+
+ // If the endpoint is bound, re-bind.
+ if ep.bound {
+ if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index e67169111..ac927569a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -720,107 +720,6 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) {
return e.stack.IPTables(), nil
}
-// Resume implements tcpip.ResumableEndpoint.Resume.
-func (e *endpoint) Resume(s *stack.Stack) {
- e.stack = s
- e.segmentQueue.setLimit(MaxUnprocessedSegments)
- e.workMu.Init()
-
- state := e.state
- switch state {
- case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
- var ss SendBufferSizeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
- if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
- panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
- }
- if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
- panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
- }
- }
- }
-
- bind := func() {
- e.state = StateInitial
- if len(e.bindAddress) == 0 {
- e.bindAddress = e.id.LocalAddress
- }
- if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil {
- panic("endpoint binding failed: " + err.String())
- }
- }
-
- switch state {
- case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
- bind()
- if len(e.connectingAddress) == 0 {
- e.connectingAddress = e.id.RemoteAddress
- // This endpoint is accepted by netstack but not yet by
- // the app. If the endpoint is IPv6 but the remote
- // address is IPv4, we need to connect as IPv6 so that
- // dual-stack mode can be properly activated.
- if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
- e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
- }
- }
- // Reset the scoreboard to reinitialize the sack information as
- // we do not restore SACK information.
- e.scoreboard.Reset()
- if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
- panic("endpoint connecting failed: " + err.String())
- }
- connectedLoading.Done()
- case StateListen:
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- bind()
- backlog := cap(e.acceptedChan)
- if err := e.Listen(backlog); err != nil {
- panic("endpoint listening failed: " + err.String())
- }
- listenLoading.Done()
- tcpip.AsyncLoading.Done()
- }()
- case StateConnecting, StateSynSent, StateSynRecv:
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- listenLoading.Wait()
- bind()
- if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
- panic("endpoint connecting failed: " + err.String())
- }
- connectingLoading.Done()
- tcpip.AsyncLoading.Done()
- }()
- case StateBound:
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- listenLoading.Wait()
- connectingLoading.Wait()
- bind()
- tcpip.AsyncLoading.Done()
- }()
- case StateClose:
- if e.isPortReserved {
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- listenLoading.Wait()
- connectingLoading.Wait()
- bind()
- e.state = StateClose
- tcpip.AsyncLoading.Done()
- }()
- }
- fallthrough
- case StateError:
- tcpip.DeleteDanglingEndpoint(e)
- }
-}
-
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
@@ -878,60 +777,95 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
return v, nil
}
-// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
- // Linux completely ignores any address passed to sendto(2) for TCP sockets
- // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
- // and opts.EndOfRecord are also ignored.
-
- e.mu.RLock()
- defer e.mu.RUnlock()
-
+// isEndpointWritableLocked checks if a given endpoint is writable
+// and also returns the number of bytes that can be written at this
+// moment. If the endpoint is not writable then it returns an error
+// indicating the reason why it's not writable.
+// Caller must hold e.mu and e.sndBufMu
+func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
if !e.state.connected() {
switch e.state {
case StateError:
- return 0, nil, e.hardError
+ return 0, e.hardError
default:
- return 0, nil, tcpip.ErrClosedForSend
+ return 0, tcpip.ErrClosedForSend
}
}
- // Nothing to do if the buffer is empty.
- if p.Size() == 0 {
- return 0, nil, nil
- }
-
- e.sndBufMu.Lock()
-
// Check if the connection has already been closed for sends.
if e.sndClosed {
- e.sndBufMu.Unlock()
- return 0, nil, tcpip.ErrClosedForSend
+ return 0, tcpip.ErrClosedForSend
}
- // Check against the limit.
avail := e.sndBufSize - e.sndBufUsed
if avail <= 0 {
+ return 0, tcpip.ErrWouldBlock
+ }
+ return avail, nil
+}
+
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
+
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
e.sndBufMu.Unlock()
- return 0, nil, tcpip.ErrWouldBlock
+ e.mu.RUnlock()
+ return 0, nil, err
}
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+
+ // Nothing to do if the buffer is empty.
+ if p.Size() == 0 {
+ return 0, nil, nil
+ }
+
+ // Copy in memory without holding sndBufMu so that worker goroutine can
+ // make progress independent of this operation.
v, perr := p.Get(avail)
if perr != nil {
- e.sndBufMu.Unlock()
return 0, nil, perr
}
- l := len(v)
- s := newSegmentFromView(&e.route, e.id, v)
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a
+ // write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due to a
+ // simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
// Add data to the send queue.
+ l := len(v)
+ s := newSegmentFromView(&e.route, e.id, v)
e.sndBufUsed += l
e.sndBufInQueue += seqnum.Size(l)
e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
if e.workMu.TryLock() {
// Do the work inline.
@@ -941,13 +875,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
- return uintptr(l), nil, nil
+ return int64(l), nil, nil
}
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -973,8 +907,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er
// Make a copy of vec so we can modify the slide headers.
vec = append([][]byte(nil), vec...)
- var num uintptr
-
+ var num int64
for s := e.rcvList.Front(); s != nil; s = s.Next() {
views := s.data.Views()
@@ -993,7 +926,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er
n := copy(vec[0], v)
v = v[n:]
vec[0] = vec[0][n:]
- num += uintptr(n)
+ num += int64(n)
}
}
}
@@ -1415,7 +1348,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
netProto = header.IPv4ProtocolNumber
addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == "\x00\x00\x00\x00" {
+ if addr.Addr == header.IPv4Any {
addr.Addr = ""
}
}
@@ -1429,13 +1362,13 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
return netProto, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- if addr.Addr == "" && addr.Port == 0 {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
return e.connect(addr, true, true)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index ef88dc618..831389ec7 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -20,6 +20,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -167,6 +168,107 @@ func (e *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.workMu.Init()
+
+ state := e.state
+ switch state {
+ case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
+ panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
+ }
+ if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
+ panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
+ }
+ }
+ }
+
+ bind := func() {
+ e.state = StateInitial
+ if len(e.bindAddress) == 0 {
+ e.bindAddress = e.id.LocalAddress
+ }
+ if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil {
+ panic("endpoint binding failed: " + err.String())
+ }
+ }
+
+ switch state {
+ case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ bind()
+ if len(e.connectingAddress) == 0 {
+ e.connectingAddress = e.id.RemoteAddress
+ // This endpoint is accepted by netstack but not yet by
+ // the app. If the endpoint is IPv6 but the remote
+ // address is IPv4, we need to connect as IPv6 so that
+ // dual-stack mode can be properly activated.
+ if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
+ e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
+ }
+ }
+ // Reset the scoreboard to reinitialize the sack information as
+ // we do not restore SACK information.
+ e.scoreboard.Reset()
+ if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ connectedLoading.Done()
+ case StateListen:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ bind()
+ backlog := cap(e.acceptedChan)
+ if err := e.Listen(backlog); err != nil {
+ panic("endpoint listening failed: " + err.String())
+ }
+ listenLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case StateConnecting, StateSynSent, StateSynRecv:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ bind()
+ if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ connectingLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case StateBound:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ connectingLoading.Wait()
+ bind()
+ tcpip.AsyncLoading.Done()
+ }()
+ case StateClose:
+ if e.isPortReserved {
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ connectingLoading.Wait()
+ bind()
+ e.state = StateClose
+ tcpip.AsyncLoading.Done()
+ }()
+ }
+ fallthrough
+ case StateError:
+ tcpip.DeleteDanglingEndpoint(e)
+ }
+}
+
// saveLastError is invoked by stateify.
func (e *endpoint) saveLastError() string {
if e.lastError == nil {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 915a98047..f79b8ec5f 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -2874,15 +2874,11 @@ func makeStack() (*stack.Stack, *tcpip.Error) {
s.SetRouteTable([]tcpip.Route{
{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
},
{
- Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv6EmptySubnet,
NIC: 1,
},
})
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index bcc0f3e28..272481aa0 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -168,15 +168,11 @@ func New(t *testing.T, mtu uint32) *Context {
s.SetRouteTable([]tcpip.Route{
{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
},
{
- Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv6EmptySubnet,
NIC: 1,
},
})
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 7c12a6092..ac5905772 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -178,53 +178,6 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) {
return e.stack.IPTables(), nil
}
-// Resume implements tcpip.ResumableEndpoint.Resume.
-func (e *endpoint) Resume(s *stack.Stack) {
- e.stack = s
-
- for _, m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
- panic(err)
- }
- }
-
- if e.state != stateBound && e.state != stateConnected {
- return
- }
-
- netProto := e.effectiveNetProtos[0]
- // Connect() and bindLocked() both assert
- //
- // netProto == header.IPv6ProtocolNumber
- //
- // before creating a multi-entry effectiveNetProtos.
- if len(e.effectiveNetProtos) > 1 {
- netProto = header.IPv6ProtocolNumber
- }
-
- var err *tcpip.Error
- if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
- if err != nil {
- panic(*err)
- }
- } else if len(e.id.LocalAddress) != 0 { // stateBound
- if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
- panic(tcpip.ErrBadLocalAddress)
- }
- }
-
- // Our saved state had a port, but we don't actually have a
- // reservation. We need to remove the port from our state, but still
- // pass it to the reservation machinery.
- id := e.id
- e.id.LocalPort = 0
- e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
- if err != nil {
- panic(*err)
- }
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -296,6 +249,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// specified address is a multicast address.
func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.id.LocalAddress
+ if isBroadcastOrMulticast(localAddr) {
+ // A packet can only originate from a unicast address (i.e., an interface).
+ localAddr = ""
+ }
+
if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
if nicid == 0 {
nicid = e.multicastNICID
@@ -315,7 +273,7 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -421,11 +379,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil {
return 0, nil, err
}
- return uintptr(len(v)), nil, nil
+ return int64(len(v)), nil, nil
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -495,7 +453,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
nicID := v.NIC
- if v.InterfaceAddr == header.IPv4Any {
+
+ // The interface address is considered not-set if it is empty or contains
+ // all-zeros. The former represent the zero-value in golang, the latter the
+ // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall.
+ allZeros := header.IPv4Any
+ if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros {
if nicID == 0 {
r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
if err == nil {
@@ -739,7 +702,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
netProto = header.IPv4ProtocolNumber
addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == "\x00\x00\x00\x00" {
+ if addr.Addr == header.IPv4Any {
addr.Addr = ""
}
@@ -758,7 +721,8 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
-func (e *endpoint) disconnect() *tcpip.Error {
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -797,9 +761,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- if addr.Addr == "" {
- return e.disconnect()
- }
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
@@ -963,8 +924,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
}
nicid := addr.NIC
- if len(addr.Addr) != 0 {
- // A local address was specified, verify that it's valid.
+ if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
+ // A local unicast address was specified, verify that it's valid.
nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nicid == 0 {
return tcpip.ErrBadLocalAddress
@@ -1113,3 +1074,7 @@ func (e *endpoint) State() uint32 {
// TODO(b/112063468): Translate internal state to values returned by Linux.
return 0
}
+
+func isBroadcastOrMulticast(a tcpip.Address) bool {
+ return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 86db36260..5cbb56120 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -15,7 +15,9 @@
package udp
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -64,3 +66,51 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
func (e *endpoint) afterLoad() {
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
+
+ for _, m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(err)
+ }
+ }
+
+ if e.state != stateBound && e.state != stateConnected {
+ return
+ }
+
+ netProto := e.effectiveNetProtos[0]
+ // Connect() and bindLocked() both assert
+ //
+ // netProto == header.IPv6ProtocolNumber
+ //
+ // before creating a multi-entry effectiveNetProtos.
+ if len(e.effectiveNetProtos) > 1 {
+ netProto = header.IPv6ProtocolNumber
+ }
+
+ var err *tcpip.Error
+ if e.state == stateConnected {
+ e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
+ if err != nil {
+ panic(err)
+ }
+ } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound
+ // A local unicast address is specified, verify that it's valid.
+ if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ // Our saved state had a port, but we don't actually have a
+ // reservation. We need to remove the port from our state, but still
+ // pass it to the reservation machinery.
+ id := e.id
+ e.id.LocalPort = 0
+ e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ if err != nil {
+ panic(err)
+ }
+}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 56c285f88..9da6edce2 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,6 +16,7 @@ package udp_test
import (
"bytes"
+ "fmt"
"math"
"math/rand"
"testing"
@@ -34,13 +35,19 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// Addresses and ports used for testing. It is recommended that tests stick to
+// using these addresses as it allows using the testFlow helper.
+// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*'
+// represents the remote endpoint.
const (
+ v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
- testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
- multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr
- V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+ stackV4MappedAddr = v4MappedAddrPrefix + stackAddr
+ testV4MappedAddr = v4MappedAddrPrefix + testAddr
+ multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
+ broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr
+ v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00"
stackAddr = "\x0a\x00\x00\x01"
stackPort = 1234
@@ -48,7 +55,7 @@ const (
testPort = 4096
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- multicastPort = 1234
+ broadcastAddr = header.IPv4Broadcast
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -56,6 +63,205 @@ const (
defaultMTU = 65536
)
+// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in
+// a packet header. These values are used to populate a header or verify one.
+// Note that because they are used in packet headers, the addresses are never in
+// a V4-mapped format.
+type header4Tuple struct {
+ srcAddr tcpip.FullAddress
+ dstAddr tcpip.FullAddress
+}
+
+// testFlow implements a helper type used for sending and receiving test
+// packets. A given test flow value defines 1) the socket endpoint used for the
+// test and 2) the type of packet send or received on the endpoint. E.g., a
+// multicastV6Only flow is a V6 multicast packet passing through a V6-only
+// endpoint. The type provides helper methods to characterize the flow (e.g.,
+// isV4) as well as return a proper header4Tuple for it.
+type testFlow int
+
+const (
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+)
+
+func (flow testFlow) String() string {
+ switch flow {
+ case unicastV4:
+ return "unicastV4"
+ case unicastV6:
+ return "unicastV6"
+ case unicastV6Only:
+ return "unicastV6Only"
+ case unicastV4in6:
+ return "unicastV4in6"
+ case multicastV4:
+ return "multicastV4"
+ case multicastV6:
+ return "multicastV6"
+ case multicastV6Only:
+ return "multicastV6Only"
+ case multicastV4in6:
+ return "multicastV4in6"
+ case broadcast:
+ return "broadcast"
+ case broadcastIn6:
+ return "broadcastIn6"
+ default:
+ return "unknown"
+ }
+}
+
+// packetDirection explains if a flow is incoming (read) or outgoing (write).
+type packetDirection int
+
+const (
+ incoming packetDirection = iota
+ outgoing
+)
+
+// header4Tuple returns the header4Tuple for the given flow and direction. Note
+// that the tuple contains no mapped addresses as those only exist at the socket
+// level but not at the packet header level.
+func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
+ var h header4Tuple
+ if flow.isV4() {
+ if d == outgoing {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
+ dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
+ }
+ } else {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
+ dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
+ }
+ }
+ if flow.isMulticast() {
+ h.dstAddr.Addr = multicastAddr
+ } else if flow.isBroadcast() {
+ h.dstAddr.Addr = broadcastAddr
+ }
+ } else { // IPv6
+ if d == outgoing {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
+ dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ }
+ } else {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
+ }
+ }
+ if flow.isMulticast() {
+ h.dstAddr.Addr = multicastV6Addr
+ }
+ }
+ return h
+}
+
+func (flow testFlow) getMcastAddr() tcpip.Address {
+ if flow.isV4() {
+ return multicastAddr
+ }
+ return multicastV6Addr
+}
+
+// mapAddrIfApplicable converts the given V4 address into its V4-mapped version
+// if it is applicable to the flow.
+func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address {
+ if flow.isMapped() {
+ return v4MappedAddrPrefix + v4Addr
+ }
+ return v4Addr
+}
+
+// netProto returns the protocol number used for the network packet.
+func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
+ if flow.isV4() {
+ return ipv4.ProtocolNumber
+ }
+ return ipv6.ProtocolNumber
+}
+
+// sockProto returns the protocol number used when creating the socket
+// endpoint for this flow.
+func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
+ switch flow {
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ return ipv6.ProtocolNumber
+ case unicastV4, multicastV4, broadcast:
+ return ipv4.ProtocolNumber
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) {
+ if flow.isV4() {
+ return checker.IPv4
+ }
+ return checker.IPv6
+}
+
+func (flow testFlow) isV6() bool { return !flow.isV4() }
+func (flow testFlow) isV4() bool {
+ return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped()
+}
+
+func (flow testFlow) isV6Only() bool {
+ switch flow {
+ case unicastV6Only, multicastV6Only:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isMulticast() bool {
+ switch flow {
+ case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isBroadcast() bool {
+ switch flow {
+ case broadcast, broadcastIn6:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isMapped() bool {
+ switch flow {
+ case unicastV4in6, multicastV4in6, broadcastIn6:
+ return true
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
@@ -65,12 +271,9 @@ type testContext struct {
wq waiter.Queue
}
-type headers struct {
- srcPort uint16
- dstPort uint16
-}
-
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
+ t.Helper()
+
s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
id, linkEP := channel.New(256, mtu, "")
@@ -91,15 +294,11 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
s.SetRouteTable([]tcpip.Route{
{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
},
{
- Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv6EmptySubnet,
NIC: 1,
},
})
@@ -117,51 +316,54 @@ func (c *testContext) cleanup() {
}
}
-func (c *testContext) createV6Endpoint(v6only bool) {
+func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
+ c.t.Helper()
+
var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
+ c.t.Fatal("NewEndpoint failed: ", err)
}
+}
- var v tcpip.V6OnlyOption
- if v6only {
- v = 1
- }
- if err := c.ep.SetSockOpt(v); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
+func (c *testContext) createEndpointForFlow(flow testFlow) {
+ c.t.Helper()
+
+ c.createEndpoint(flow.sockProto())
+ if flow.isV6Only() {
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ } else if flow.isBroadcast() {
+ if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil {
+ c.t.Fatal("SetSockOpt failed:", err)
+ }
}
}
-func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte {
+// getPacketAndVerify reads a packet from the link endpoint and verifies the
+// header against expected values from the given test flow. In addition, it
+// calls any extra checker functions provided.
+func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
+ c.t.Helper()
+
select {
case p := <-c.linkEP.C:
- if p.Proto != protocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber)
+ if p.Proto != flow.netProto() {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
}
b := make([]byte, len(p.Header)+len(p.Payload))
copy(b, p.Header)
copy(b[len(p.Header):], p.Payload)
- var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker)
- var srcAddr, dstAddr tcpip.Address
- switch protocolNumber {
- case ipv4.ProtocolNumber:
- checkerFn = checker.IPv4
- srcAddr, dstAddr = stackAddr, testAddr
- if multicast {
- dstAddr = multicastAddr
- }
- case ipv6.ProtocolNumber:
- checkerFn = checker.IPv6
- srcAddr, dstAddr = stackV6Addr, testV6Addr
- if multicast {
- dstAddr = multicastV6Addr
- }
- default:
- c.t.Fatalf("unknown protocol %d", protocolNumber)
- }
- checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr))
+ h := flow.header4Tuple(outgoing)
+ checkers := append(
+ checkers,
+ checker.SrcAddr(h.srcAddr.Addr),
+ checker.DstAddr(h.dstAddr.Addr),
+ checker.UDP(checker.DstPort(h.dstAddr.Port)),
+ )
+ flow.checkerFn()(c.t, b, checkers...)
return b
case <-time.After(2 * time.Second):
@@ -171,7 +373,22 @@ func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, mult
return nil
}
-func (c *testContext) sendV6Packet(payload []byte, h *headers) {
+// injectPacket creates a packet of the given flow and with the given payload,
+// and injects it into the link endpoint.
+func (c *testContext) injectPacket(flow testFlow, payload []byte) {
+ c.t.Helper()
+
+ h := flow.header4Tuple(incoming)
+ if flow.isV4() {
+ c.injectV4Packet(payload, &h)
+ } else {
+ c.injectV6Packet(payload, &h)
+ }
+}
+
+// injectV6Packet creates a V6 test packet with the given payload and header
+// values, and injects it into the link endpoint.
+func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -182,20 +399,20 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) {
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: testV6Addr,
- DstAddr: stackV6Addr,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
u.Encode(&header.UDPFields{
- SrcPort: h.srcPort,
- DstPort: h.dstPort,
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -205,7 +422,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) {
c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
}
-func (c *testContext) sendPacket(payload []byte, h *headers) {
+// injectV6Packet creates a V4 test packet with the given payload and header
+// values, and injects it into the link endpoint.
+func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -217,21 +436,21 @@ func (c *testContext) sendPacket(payload []byte, h *headers) {
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
- SrcAddr: testAddr,
- DstAddr: stackAddr,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
ip.SetChecksum(^ip.CalculateChecksum())
// Initialize the UDP header.
u := header.UDP(buf[header.IPv4MinimumSize:])
u.Encode(&header.UDPFields{
- SrcPort: h.srcPort,
- DstPort: h.dstPort,
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testAddr, stackAddr, uint16(len(u)))
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -253,7 +472,7 @@ func TestBindPortReuse(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
var eps [5]tcpip.Endpoint
reusePortOpt := tcpip.ReusePortOption(1)
@@ -296,9 +515,9 @@ func TestBindPortReuse(t *testing.T) {
// Send a packet.
port := uint16(i % nports)
payload := newPayload()
- c.sendV6Packet(payload, &headers{
- srcPort: testPort + port,
- dstPort: stackPort,
+ c.injectV6Packet(payload, &header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port},
+ dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
})
var addr tcpip.FullAddress
@@ -333,13 +552,14 @@ func TestBindPortReuse(t *testing.T) {
}
}
-func testV4Read(c *testContext) {
- // Send a packet.
+// testRead sends a packet of the given test flow into the stack by injecting it
+// into the link endpoint. It then reads it from the UDP endpoint and verifies
+// its correctness.
+func testRead(c *testContext, flow testFlow) {
+ c.t.Helper()
+
payload := newPayload()
- c.sendPacket(payload, &headers{
- srcPort: testPort,
- dstPort: stackPort,
- })
+ c.injectPacket(flow, payload)
// Try to receive the data.
we, ch := waiter.NewChannelEntry(nil)
@@ -363,8 +583,9 @@ func testV4Read(c *testContext) {
}
// Check the peer address.
- if addr.Addr != testAddr {
- c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
+ h := flow.header4Tuple(incoming)
+ if addr.Addr != h.srcAddr.Addr {
+ c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr)
}
// Check the payload.
@@ -377,7 +598,7 @@ func TestBindEphemeralPort(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
t.Fatalf("ep.Bind(...) failed: %v", err)
@@ -388,7 +609,7 @@ func TestBindReservedPort(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
@@ -447,7 +668,7 @@ func TestV4ReadOnV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV4in6)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
@@ -455,29 +676,29 @@ func TestV4ReadOnV6(t *testing.T) {
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4in6)
}
func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV4in6)
// Bind to v4 mapped wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}); err != nil {
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4in6)
}
func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV4in6)
// Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
@@ -485,69 +706,29 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4in6)
}
func TestV6ReadOnV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV6)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- // Send a packet.
- payload := newPayload()
- c.sendV6Packet(payload, &headers{
- srcPort: testPort,
- dstPort: stackPort,
- })
-
- // Try to receive the data.
- we, ch := waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&we, waiter.EventIn)
- defer c.wq.EventUnregister(&we)
-
- var addr tcpip.FullAddress
- v, _, err := c.ep.Read(&addr)
- if err == tcpip.ErrWouldBlock {
- // Wait for data to become available.
- select {
- case <-ch:
- v, _, err = c.ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for data")
- }
- }
-
- // Check the peer address.
- if addr.Addr != testV6Addr {
- c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
- }
-
- // Check the payload.
- if !bytes.Equal(payload, v) {
- c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
- }
+ // Test acceptance.
+ testRead(c, unicastV6)
}
func TestV4ReadOnV4(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- // Create v4 UDP endpoint.
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
+ c.createEndpointForFlow(unicastV4)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
@@ -555,62 +736,123 @@ func TestV4ReadOnV4(t *testing.T) {
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4)
}
-func testV4Write(c *testContext) uint16 {
- // Write to V4 mapped address.
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
- })
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
+// address and receive data sent to that address.
+func TestReadOnBoundToMulticast(t *testing.T) {
+ // FIXME(b/128189410): multicastV4in6 currently doesn't work as
+ // AddMembershipOption doesn't handle V4in6 addresses.
+ for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to multicast address.
+ mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr())
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ // Join multicast group.
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatal("SetSockOpt failed:", err)
+ }
+
+ testRead(c, flow)
+ })
}
- if n != uintptr(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+}
+
+// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
+// address and receive broadcast data on it.
+func TestV4ReadOnBoundToBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to broadcast address.
+ bcastAddr := flow.mapAddrIfApplicable(broadcastAddr)
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Test acceptance.
+ testRead(c, flow)
+ })
}
+}
- // Check that we received the packet.
- b := c.getPacket(ipv4.ProtocolNumber, false)
- udp := header.UDP(header.IPv4(b).Payload())
- checker.IPv4(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
+// testFailingWrite sends a packet of the given test flow into the UDP endpoint
+// and verifies it fails with the provided error code.
+func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
+ c.t.Helper()
- // Check the payload.
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ h := flow.header4Tuple(outgoing)
+ writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
+
+ payload := buffer.View(newPayload())
+ _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
+ })
+ if gotErr != wantErr {
+ c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
}
+}
- return udp.SourcePort()
+// testWrite sends a packet of the given test flow from the UDP endpoint to the
+// flow's destination address:port. It then receives it from the link endpoint
+// and verifies its correctness including any additional checker functions
+// provided.
+func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ return testWriteInternal(c, flow, true, checkers...)
}
-func testV6Write(c *testContext) uint16 {
- // Write to v6 address.
+// testWriteWithoutDestination sends a packet of the given test flow from the
+// UDP endpoint without giving a destination address:port. It then receives it
+// from the link endpoint and verifies its correctness including any additional
+// checker functions provided.
+func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ return testWriteInternal(c, flow, false, checkers...)
+}
+
+func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+
+ writeOpts := tcpip.WriteOptions{}
+ if setDest {
+ h := flow.header4Tuple(outgoing)
+ writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
+ writeOpts = tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
+ }
+ }
payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- })
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
- if n != uintptr(len(payload)) {
+ if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
- // Check that we received the packet.
- b := c.getPacket(ipv6.ProtocolNumber, false)
- udp := header.UDP(header.IPv6(b).Payload())
- checker.IPv6(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
-
- // Check the payload.
+ // Received the packet and check the payload.
+ b := c.getPacketAndVerify(flow, checkers...)
+ var udp header.UDP
+ if flow.isV4() {
+ udp = header.UDP(header.IPv4(b).Payload())
+ } else {
+ udp = header.UDP(header.IPv6(b).Payload())
+ }
if !bytes.Equal(payload, udp.Payload()) {
c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
}
@@ -619,8 +861,10 @@ func testV6Write(c *testContext) uint16 {
}
func testDualWrite(c *testContext) uint16 {
- v4Port := testV4Write(c)
- v6Port := testV6Write(c)
+ c.t.Helper()
+
+ v4Port := testWrite(c, unicastV4in6)
+ v6Port := testWrite(c, unicastV6)
if v4Port != v6Port {
c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
}
@@ -632,7 +876,7 @@ func TestDualWriteUnbound(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
testDualWrite(c)
}
@@ -641,7 +885,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
@@ -658,69 +902,51 @@ func TestDualWriteConnectedToV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- testV6Write(c)
+ testWrite(c, unicastV6)
// Write to V4 mapped address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
- })
- if err != tcpip.ErrNetworkUnreachable {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable)
- }
+ testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
}
func TestDualWriteConnectedToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- testV4Write(c)
+ testWrite(c, unicastV4in6)
// Write to v6 address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- })
- if err != tcpip.ErrInvalidEndpointState {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState)
- }
+ testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
}
func TestV4WriteOnV6Only(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(true)
+ c.createEndpointForFlow(unicastV6Only)
// Write to V4 mapped address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
- })
- if err != tcpip.ErrNoRoute {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
- }
+ testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute)
}
func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Bind to v4 mapped address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
@@ -728,84 +954,154 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
}
// Write to v6 address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- })
- if err != tcpip.ErrInvalidEndpointState {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState)
- }
+ testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
}
func TestV6WriteOnConnected(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
}
- // Write without destination.
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
- }
- if n != uintptr(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
- }
-
- // Check that we received the packet.
- b := c.getPacket(ipv6.ProtocolNumber, false)
- udp := header.UDP(header.IPv6(b).Payload())
- checker.IPv6(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
-
- // Check the payload.
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
- }
+ testWriteWithoutDestination(c, unicastV6)
}
func TestV4WriteOnConnected(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
}
- // Write without destination.
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+ testWriteWithoutDestination(c, unicastV4)
+}
+
+// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
+// that is bound to a V4 multicast address.
+func TestWriteOnBoundToV4Multicast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ testWrite(c, flow)
+ })
}
- if n != uintptr(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+}
+
+// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
+// socket that is bound to a V4-mapped multicast address.
+func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4Mapped mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
}
+}
- // Check that we received the packet.
- b := c.getPacket(ipv4.ProtocolNumber, false)
- udp := header.UDP(header.IPv4(b).Payload())
- checker.IPv4(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
+// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
+// socket that is bound to a V6 multicast address.
+func TestWriteOnBoundToV6Multicast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV6, multicastV6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- // Check the payload.
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ c.createEndpointForFlow(flow)
+
+ // Bind to V6 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
+// V6-only socket that is bound to a V6 multicast address.
+func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV6Only, multicastV6Only} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V6 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToBroadcast checks that we can send packets out of a
+// socket that is bound to the broadcast address.
+func TestWriteOnBoundToBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4 broadcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
+// socket that is bound to the V4-mapped broadcast address.
+func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4Mapped mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
}
}
@@ -814,18 +1110,14 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
defer c.cleanup()
// Create IPv4 UDP endpoint
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
+ c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- testV4Read(c)
+ testRead(c, unicastV4)
var want uint64 = 1
if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
@@ -837,7 +1129,7 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
testDualWrite(c)
@@ -847,244 +1139,102 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
}
}
-func setSockOptVariants(t *testing.T, optFunc func(*testing.T, string, tcpip.NetworkProtocolNumber, string)) {
- for _, name := range []string{"v4", "v6", "dual"} {
- t.Run(name, func(t *testing.T) {
- var networkProtocolNumber tcpip.NetworkProtocolNumber
- switch name {
- case "v4":
- networkProtocolNumber = ipv4.ProtocolNumber
- case "v6", "dual":
- networkProtocolNumber = ipv6.ProtocolNumber
- default:
- t.Fatal("unknown test variant")
- }
+func TestTTL(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- var variants []string
- switch name {
- case "v4":
- variants = []string{"v4"}
- case "v6":
- variants = []string{"v6"}
- case "dual":
- variants = []string{"v6", "mapped"}
- }
+ c.createEndpointForFlow(flow)
- for _, variant := range variants {
- t.Run(variant, func(t *testing.T) {
- optFunc(t, name, networkProtocolNumber, variant)
- })
+ const multicastTTL = 42
+ if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
}
- })
- }
-}
-func TestTTL(t *testing.T) {
- payload := tcpip.SlicePayload(buffer.View(newPayload()))
-
- setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) {
- for _, typ := range []string{"unicast", "multicast"} {
- t.Run(typ, func(t *testing.T) {
- var addr tcpip.Address
- var port uint16
- switch typ {
- case "unicast":
- port = testPort
- switch variant {
- case "v4":
- addr = testAddr
- case "mapped":
- addr = testV4MappedAddr
- case "v6":
- addr = testV6Addr
- default:
- t.Fatal("unknown test variant")
- }
- case "multicast":
- port = multicastPort
- switch variant {
- case "v4":
- addr = multicastAddr
- case "mapped":
- addr = multicastV4MappedAddr
- case "v6":
- addr = multicastV6Addr
- default:
- t.Fatal("unknown test variant")
- }
- default:
- t.Fatal("unknown test variant")
+ var wantTTL uint8
+ if flow.isMulticast() {
+ wantTTL = multicastTTL
+ } else {
+ var p stack.NetworkProtocol
+ if flow.isV4() {
+ p = ipv4.NewProtocol()
+ } else {
+ p = ipv6.NewProtocol()
}
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- switch name {
- case "v4":
- case "v6":
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
- case "dual":
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
- default:
- t.Fatal("unknown test variant")
- }
-
- const multicastTTL = 42
- if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ t.Fatal(err)
}
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ }
- n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
- }
- if n != uintptr(len(payload)) {
- c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload))
- }
+ testWrite(c, flow, checker.TTL(wantTTL))
+ })
+ }
+}
- checkerFn := checker.IPv4
- switch variant {
- case "v4", "mapped":
- case "v6":
- checkerFn = checker.IPv6
- default:
- t.Fatal("unknown test variant")
- }
- var wantTTL uint8
- var multicast bool
- switch typ {
- case "unicast":
- multicast = false
- switch variant {
- case "v4", "mapped":
- ep, err := ipv4.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
- wantTTL = ep.DefaultTTL()
- ep.Close()
- case "v6":
- ep, err := ipv6.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
- wantTTL = ep.DefaultTTL()
- ep.Close()
- default:
- t.Fatal("unknown test variant")
- }
- case "multicast":
- wantTTL = multicastTTL
- multicast = true
- default:
- t.Fatal("unknown test variant")
- }
+func TestMulticastInterfaceOption(t *testing.T) {
+ for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ for _, bindTyp := range []string{"bound", "unbound"} {
+ t.Run(bindTyp, func(t *testing.T) {
+ for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
+ t.Run(optTyp, func(t *testing.T) {
+ h := flow.header4Tuple(outgoing)
+ mcastAddr := h.dstAddr.Addr
+ localIfAddr := h.srcAddr.Addr
+
+ var ifoptSet tcpip.MulticastInterfaceOption
+ switch optTyp {
+ case "use local-addr":
+ ifoptSet.InterfaceAddr = localIfAddr
+ case "use NICID":
+ ifoptSet.NIC = 1
+ case "use local-addr and NIC":
+ ifoptSet.InterfaceAddr = localIfAddr
+ ifoptSet.NIC = 1
+ default:
+ t.Fatal("unknown test variant")
+ }
- var networkProtocolNumber tcpip.NetworkProtocolNumber
- switch variant {
- case "v4", "mapped":
- networkProtocolNumber = ipv4.ProtocolNumber
- case "v6":
- networkProtocolNumber = ipv6.ProtocolNumber
- default:
- t.Fatal("unknown test variant")
- }
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(flow.sockProto())
+
+ if bindTyp == "bound" {
+ // Bind the socket by connecting to the multicast address.
+ // This may have an influence on how the multicast interface
+ // is set.
+ addr := tcpip.FullAddress{
+ Addr: flow.mapAddrIfApplicable(mcastAddr),
+ Port: stackPort,
+ }
+ if err := c.ep.Connect(addr); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+ }
- b := c.getPacket(networkProtocolNumber, multicast)
- checkerFn(c.t, b,
- checker.TTL(wantTTL),
- checker.UDP(
- checker.DstPort(port),
- ),
- )
- })
- }
- })
-}
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
-func TestMulticastInterfaceOption(t *testing.T) {
- setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) {
- for _, bindTyp := range []string{"bound", "unbound"} {
- t.Run(bindTyp, func(t *testing.T) {
- for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
- t.Run(optTyp, func(t *testing.T) {
- var mcastAddr, localIfAddr tcpip.Address
- switch variant {
- case "v4":
- mcastAddr = multicastAddr
- localIfAddr = stackAddr
- case "mapped":
- mcastAddr = multicastV4MappedAddr
- localIfAddr = stackAddr
- case "v6":
- mcastAddr = multicastV6Addr
- localIfAddr = stackV6Addr
- default:
- t.Fatal("unknown test variant")
- }
-
- var ifoptSet tcpip.MulticastInterfaceOption
- switch optTyp {
- case "use local-addr":
- ifoptSet.InterfaceAddr = localIfAddr
- case "use NICID":
- ifoptSet.NIC = 1
- case "use local-addr and NIC":
- ifoptSet.InterfaceAddr = localIfAddr
- ifoptSet.NIC = 1
- default:
- t.Fatal("unknown test variant")
- }
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if bindTyp == "bound" {
- // Bind the socket by connecting to the multicast address.
- // This may have an influence on how the multicast interface
- // is set.
- addr := tcpip.FullAddress{
- Addr: mcastAddr,
- Port: multicastPort,
+ // Verify multicast interface addr and NIC were set correctly.
+ // Note that NIC must be 1 since this is our outgoing interface.
+ ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
+ var ifoptGot tcpip.MulticastInterfaceOption
+ if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
+ c.t.Fatalf("GetSockOpt failed: %v", err)
}
- if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ if ifoptGot != ifoptWant {
+ c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
}
- }
-
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
-
- // Verify multicast interface addr and NIC were set correctly.
- // Note that NIC must be 1 since this is our outgoing interface.
- ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
- var ifoptGot tcpip.MulticastInterfaceOption
- if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %v", err)
- }
- if ifoptGot != ifoptWant {
- c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
- }
- })
- }
- })
- }
- })
+ })
+ }
+ })
+ }
+ })
+ }
}