diff options
Diffstat (limited to 'pkg/tcpip')
36 files changed, 1977 insertions, 1185 deletions
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index c40924852..0d2637ee4 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -24,6 +24,7 @@ go_test( embed = [":gonet"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/header", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 308f620e5..cd6ce930a 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -404,7 +404,7 @@ func (c *Conn) Write(b []byte) (int, error) { } } - var n uintptr + var n int64 var resCh <-chan struct{} n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) nbytes += int(n) @@ -556,32 +556,50 @@ type PacketConn struct { wq *waiter.Queue } -// NewPacketConn creates a new PacketConn. -func NewPacketConn(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { - // Create UDP endpoint and bind it. +// DialUDP creates a new PacketConn. +// +// If laddr is nil, a local address is automatically chosen. +// +// If raddr is nil, the PacketConn is left unconnected. +func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) if err != nil { return nil, errors.New(err.String()) } - if err := ep.Bind(addr); err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "bind", - Net: "udp", - Addr: fullToUDPAddr(addr), - Err: errors.New(err.String()), + if laddr != nil { + if err := ep.Bind(*laddr); err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "bind", + Net: "udp", + Addr: fullToUDPAddr(*laddr), + Err: errors.New(err.String()), + } } } - c := &PacketConn{ + c := PacketConn{ stack: s, ep: ep, wq: &wq, } c.deadlineTimer.init() - return c, nil + + if raddr != nil { + if err := c.ep.Connect(*raddr); err != nil { + c.ep.Close() + return nil, &net.OpError{ + Op: "connect", + Net: "udp", + Addr: fullToUDPAddr(*raddr), + Err: errors.New(err.String()), + } + } + } + + return &c, nil } func (c *PacketConn) newOpError(op string, err error) *net.OpError { diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 39efe44c7..672f026b2 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -26,6 +26,7 @@ import ( "golang.org/x/net/nettest" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -69,17 +70,13 @@ func newLoopbackStack() (*stack.Stack, *tcpip.Error) { s.SetRouteTable([]tcpip.Route{ // IPv4 { - Destination: tcpip.Address(strings.Repeat("\x00", 4)), - Mask: tcpip.AddressMask(strings.Repeat("\x00", 4)), - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: NICID, }, // IPv6 { - Destination: tcpip.Address(strings.Repeat("\x00", 16)), - Mask: tcpip.AddressMask(strings.Repeat("\x00", 16)), - Gateway: "", + Destination: header.IPv6EmptySubnet, NIC: NICID, }, }) @@ -371,9 +368,9 @@ func TestUDPForwarder(t *testing.T) { }) s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket) - c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber) + c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) if err != nil { - t.Fatal("NewPacketConn(port 5):", err) + t.Fatal("DialUDP(bind port 5):", err) } sent := "abc123" @@ -452,13 +449,13 @@ func TestPacketConnTransfer(t *testing.T) { addr2 := tcpip.FullAddress{NICID, ip2, 11311} s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - c1, err := NewPacketConn(s, addr1, ipv4.ProtocolNumber) + c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber) if err != nil { - t.Fatal("NewPacketConn(port 4):", err) + t.Fatal("DialUDP(bind port 4):", err) } - c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber) + c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) if err != nil { - t.Fatal("NewPacketConn(port 5):", err) + t.Fatal("DialUDP(bind port 5):", err) } c1.SetDeadline(time.Now().Add(time.Second)) @@ -491,6 +488,50 @@ func TestPacketConnTransfer(t *testing.T) { } } +func TestConnectedPacketConnTransfer(t *testing.T) { + s, e := newLoopbackStack() + if e != nil { + t.Fatalf("newLoopbackStack() = %v", e) + } + + ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) + addr := tcpip.FullAddress{NICID, ip, 11211} + s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + + c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber) + if err != nil { + t.Fatal("DialUDP(bind port 4):", err) + } + c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber) + if err != nil { + t.Fatal("DialUDP(bind port 5):", err) + } + + c1.SetDeadline(time.Now().Add(time.Second)) + c2.SetDeadline(time.Now().Add(time.Second)) + + sent := "abc123" + if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) { + t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil) + } + recv := make([]byte, len(sent)) + n, err := c1.Read(recv) + if err != nil || n != len(recv) { + t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil) + } + + if recv := string(recv); recv != sent { + t.Errorf("got recv = %q, want = %q", recv, sent) + } + + if err := c1.Close(); err != nil { + t.Error("c1.Close():", err) + } + if err := c2.Close(); err != nil { + t.Error("c2.Close():", err) + } +} + func makePipe() (c1, c2 net.Conn, stop func(), err error) { s, e := newLoopbackStack() if e != nil { diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 94a3af289..17fc9c68e 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -111,6 +111,15 @@ const ( IPv4FlagDontFragment ) +// IPv4EmptySubnet is the empty IPv4 subnet. +var IPv4EmptySubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any)) + if err != nil { + panic(err) + } + return subnet +}() + // IPVersion returns the version of IP used in the given packet. It returns -1 // if the packet is not large enough to contain the version field. func IPVersion(b []byte) int { diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 95fe8bfc3..31be42ce0 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -82,6 +82,15 @@ const ( IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" ) +// IPv6EmptySubnet is the empty IPv6 subnet. +var IPv6EmptySubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any)) + if err != nil { + panic(err) + } + return subnet +}() + // PayloadLength returns the value of the "payload length" field of the ipv6 // header. func (b IPv6) PayloadLength() uint16 { diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 11ba31ca4..088eb8a21 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -7,8 +7,9 @@ go_library( srcs = [ "blockingpoll_amd64.s", "blockingpoll_arm64.s", + "blockingpoll_noyield_unsafe.go", "blockingpoll_unsafe.go", - "blockingpoll_stub_unsafe.go", + "blockingpoll_yield_unsafe.go", "errors.go", "rawfile_unsafe.go", ], diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s index 0bc873a01..b62888b93 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s +++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/tcpip/link/rawfile/blockingpoll_stub_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go index 621ab8d29..621ab8d29 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_stub_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go diff --git a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go index eeca47d78..84dc0e918 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go @@ -12,49 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build linux,amd64 linux,arm64 -// +build go1.12 -// +build !go1.14 - -// Check go:linkname function signatures when updating Go version. +// +build linux,!amd64 package rawfile import ( "syscall" - _ "unsafe" // for go:linkname + "unsafe" ) -//go:noescape -func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) - -// Use go:linkname to call into the runtime. As of Go 1.12 this has to -// be done from Go code so that we make an ABIInternal call to an -// ABIInternal function; see https://golang.org/issue/27539. - -// We need to call both entersyscallblock and exitsyscall this way so -// that the runtime's check on the stack pointer lines up. - -// Note that calling an unexported function in the runtime package is -// unsafe and this hack is likely to break in future Go releases. - -//go:linkname entersyscallblock runtime.entersyscallblock -func entersyscallblock() - -//go:linkname exitsyscall runtime.exitsyscall -func exitsyscall() - -// These forwarding functions must be nosplit because 1) we must -// disallow preemption between entersyscallblock and exitsyscall, and -// 2) we have an untyped assembly frame on the stack which can not be -// grown or moved. - -//go:nosplit -func callEntersyscallblock() { - entersyscallblock() -} +// BlockingPoll is just a stub function that forwards to the ppoll() system call +// on non-amd64 platforms. +func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) { + n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)), + uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0) -//go:nosplit -func callExitsyscall() { - exitsyscall() + return int(n), e } diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go new file mode 100644 index 000000000..dda3b10a6 --- /dev/null +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -0,0 +1,66 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux,amd64 linux,arm64 +// +build go1.12 +// +build !go1.14 + +// Check go:linkname function signatures when updating Go version. + +package rawfile + +import ( + "syscall" + _ "unsafe" // for go:linkname +) + +// BlockingPoll on amd64/arm64 makes the ppoll() syscall while calling the +// version of entersyscall that relinquishes the P so that other Gs can +// run. This is meant to be called in cases when the syscall is expected to +// block. On non amd64/arm64 platforms it just forwards to the ppoll() system +// call. +// +//go:noescape +func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) + +// Use go:linkname to call into the runtime. As of Go 1.12 this has to +// be done from Go code so that we make an ABIInternal call to an +// ABIInternal function; see https://golang.org/issue/27539. + +// We need to call both entersyscallblock and exitsyscall this way so +// that the runtime's check on the stack pointer lines up. + +// Note that calling an unexported function in the runtime package is +// unsafe and this hack is likely to break in future Go releases. + +//go:linkname entersyscallblock runtime.entersyscallblock +func entersyscallblock() + +//go:linkname exitsyscall runtime.exitsyscall +func exitsyscall() + +// These forwarding functions must be nosplit because 1) we must +// disallow preemption between entersyscallblock and exitsyscall, and +// 2) we have an untyped assembly frame on the stack which can not be +// grown or moved. + +//go:nosplit +func callEntersyscallblock() { + entersyscallblock() +} + +//go:nosplit +func callExitsyscall() { + exitsyscall() +} diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index fc584c6a4..36c8c46fc 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -360,10 +360,9 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() + details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) + size -= header.UDPMinimumSize } - size -= header.UDPMinimumSize - - details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) case header.TCPProtocolNumber: transName = "tcp" diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index e477046db..4c4b54469 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -66,9 +66,7 @@ func newTestContext(t *testing.T) *testContext { } s.SetRouteTable([]tcpip.Route{{ - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }}) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 55e9eec99..6bbfcd97f 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -173,8 +173,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv4.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ - Destination: ipv4SubnetAddr, - Mask: ipv4SubnetMask, + Destination: header.IPv4EmptySubnet, Gateway: ipv4Gateway, NIC: 1, }}) @@ -187,8 +186,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv6.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ - Destination: ipv6SubnetAddr, - Mask: ipv6SubnetMask, + Destination: header.IPv6EmptySubnet, Gateway: ipv6Gateway, NIC: 1, }}) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 3207a3d46..1b5a55bea 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -52,9 +52,7 @@ func TestExcludeBroadcast(t *testing.T) { } s.SetRouteTable([]tcpip.Route{{ - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }}) @@ -247,14 +245,22 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32 _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) linkEPId := stack.RegisterLinkEndpoint(linkEP) s.CreateNIC(1, linkEPId) - s.AddAddress(1, ipv4.ProtocolNumber, "\x10\x00\x00\x01") - s.SetRouteTable([]tcpip.Route{{ - Destination: "\x10\x00\x00\x02", - Mask: "\xff\xff\xff\xff", - Gateway: "", - NIC: 1, - }}) - r, err := s.FindRoute(0, "\x10\x00\x00\x01", "\x10\x00\x00\x02", ipv4.ProtocolNumber, false /* multicastLoop */) + const ( + src = "\x10\x00\x00\x01" + dst = "\x10\x00\x00\x02" + ) + s.AddAddress(1, ipv4.ProtocolNumber, src) + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) if err != nil { t.Fatalf("s.FindRoute got %v, want %v", err, nil) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 726362c87..d0dc72506 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -91,13 +91,18 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) } } - s.SetRouteTable( - []tcpip.Route{{ - Destination: lladdr1, - Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), - NIC: 1, - }}, - ) + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } netProto := s.NetworkProtocolInstance(ProtocolNumber) if netProto == nil { @@ -237,17 +242,23 @@ func newTestContext(t *testing.T) *testContext { t.Fatalf("AddAddress sn lladdr1: %v", err) } + subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } c.s0.SetRouteTable( []tcpip.Route{{ - Destination: lladdr1, - Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), + Destination: subnet0, NIC: 1, }}, ) + subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) + if err != nil { + t.Fatal(err) + } c.s1.SetRouteTable( []tcpip.Route{{ - Destination: lladdr0, - Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), + Destination: subnet1, NIC: 1, }}, ) diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index 996939581..a57752a7c 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -8,6 +8,7 @@ go_binary( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/rawfile", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 3ac381631..e2021cd15 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -52,6 +52,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -152,9 +153,7 @@ func main() { // Add default route. s.SetRouteTable([]tcpip.Route{ { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }, }) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index da425394a..1716be285 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -149,12 +149,15 @@ func main() { log.Fatal(err) } + subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) + if err != nil { + log.Fatal(err) + } + // Add default route. s.SetRouteTable([]tcpip.Route{ { - Destination: tcpip.Address(strings.Repeat("\x00", len(addr))), - Mask: tcpip.AddressMask(strings.Repeat("\x00", len(addr))), - Gateway: "", + Destination: subnet, NIC: 1, }, }) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dc28dc970..04b63d783 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -139,7 +139,7 @@ func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Add if list, ok := n.primary[protocol]; ok { for e := list.Front(); e != nil; e = e.Next() { ref := e.(*referencedNetworkEndpoint) - if ref.holdsInsertRef && ref.tryIncRef() { + if ref.kind == permanent && ref.tryIncRef() { r = ref break } @@ -178,7 +178,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN case header.IPv4Broadcast, header.IPv4Any: continue } - if r.tryIncRef() { + if r.isValidForOutgoing() && r.tryIncRef() { return r } } @@ -186,46 +186,124 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return nil } +func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { + return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing) +} + +// getRefEpOrCreateTemp returns the referenced network endpoint for the given +// protocol and address. If none exists a temporary one may be created if +// we are in promiscuous mode or spoofing. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - ref := n.endpoints[id] - if ref != nil && !ref.tryIncRef() { - ref = nil + + if ref, ok := n.endpoints[id]; ok { + // An endpoint with this id exists, check if it can be used and return it. + switch ref.kind { + case permanentExpired: + if !spoofingOrPromiscuous { + n.mu.RUnlock() + return nil + } + fallthrough + case temporary, permanent: + if ref.tryIncRef() { + n.mu.RUnlock() + return ref + } + } } - spoofing := n.spoofing + + // A usable reference was not found, create a temporary one if requested by + // the caller or if the address is found in the NIC's subnets. + createTempEP := spoofingOrPromiscuous + if !createTempEP { + for _, sn := range n.subnets { + if sn.Contains(address) { + createTempEP = true + break + } + } + } + n.mu.RUnlock() - if ref != nil || !spoofing { - return ref + if !createTempEP { + return nil } // Try again with the lock in exclusive mode. If we still can't get the // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - ref = n.endpoints[id] - if ref == nil || !ref.tryIncRef() { - if netProto, ok := n.stack.networkProtocols[protocol]; ok { - ref, _ = n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb, true) - if ref != nil { - ref.holdsInsertRef = false - } + if ref, ok := n.endpoints[id]; ok { + // No need to check the type as we are ok with expired endpoints at this + // point. + if ref.tryIncRef() { + n.mu.Unlock() + return ref } + // tryIncRef failing means the endpoint is scheduled to be removed once the + // lock is released. Remove it here so we can create a new (temporary) one. + // The removal logic waiting for the lock handles this case. + n.removeEndpointLocked(ref) } + + // Add a new temporary endpoint. + netProto, ok := n.stack.networkProtocols[protocol] + if !ok { + n.mu.Unlock() + return nil + } + ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: netProto.DefaultPrefixLen(), + }, + }, peb, temporary) + n.mu.Unlock() return ref } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if ref, ok := n.endpoints[id]; ok { + switch ref.kind { + case permanent: + // The NIC already have a permanent endpoint with that address. + return nil, tcpip.ErrDuplicateAddress + case permanentExpired, temporary: + // Promote the endpoint to become permanent. + if ref.tryIncRef() { + ref.kind = permanent + return ref, nil + } + // tryIncRef failing means the endpoint is scheduled to be removed once + // the lock is released. Remove it here so we can create a new + // (permanent) one. The removal logic waiting for the lock handles this + // case. + n.removeEndpointLocked(ref) + } + } + return n.addAddressLocked(protocolAddress, peb, permanent) +} + +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { + // Sanity check. + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if _, ok := n.endpoints[id]; ok { + // Endpoint already exists. + return nil, tcpip.ErrDuplicateAddress + } + netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] if !ok { return nil, tcpip.ErrUnknownProtocol @@ -236,22 +314,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar if err != nil { return nil, err } - - id := *ep.ID() - if ref, ok := n.endpoints[id]; ok { - if !replace { - return nil, tcpip.ErrDuplicateAddress - } - - n.removeEndpointLocked(ref) - } - ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - holdsInsertRef: true, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, } // Set up cache if link address resolution exists for this protocol. @@ -284,7 +352,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addAddressLocked(protocolAddress, peb, false) + _, err := n.addPermanentAddressLocked(protocolAddress, peb) n.mu.Unlock() return err @@ -296,6 +364,12 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { defer n.mu.RUnlock() addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { + // Don't include expired or tempory endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.kind { + case permanentExpired, temporary: + continue + } addrs = append(addrs, tcpip.ProtocolAddress{ Protocol: ref.protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -361,13 +435,16 @@ func (n *NIC) Subnets() []tcpip.Subnet { func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { id := *r.ep.ID() - // Nothing to do if the reference has already been replaced with a - // different one. + // Nothing to do if the reference has already been replaced with a different + // one. This happens in the case where 1) this endpoint's ref count hit zero + // and was waiting (on the lock) to be removed and 2) the same address was + // re-added in the meantime by removing this endpoint from the list and + // adding a new one. if n.endpoints[id] != r { return } - if r.holdsInsertRef { + if r.kind == permanent { panic("Reference count dropped to zero before being removed") } @@ -386,14 +463,13 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { n.mu.Unlock() } -func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { +func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || !r.holdsInsertRef { + if r == nil || r.kind != permanent { return tcpip.ErrBadLocalAddress } - r.holdsInsertRef = false - + r.kind = permanentExpired r.decRefLocked() return nil @@ -403,7 +479,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.removeAddressLocked(addr) + return n.removePermanentAddressLocked(addr) } // joinGroup adds a new endpoint for the given multicast address, if none @@ -419,13 +495,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address if !ok { return tcpip.ErrUnknownProtocol } - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ + if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: addr, PrefixLen: netProto.DefaultPrefixLen(), }, - }, NeverPrimaryEndpoint, false); err != nil { + }, NeverPrimaryEndpoint); err != nil { return err } } @@ -447,7 +523,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress case 1: // This is the last one, clean up. - if err := n.removeAddressLocked(addr); err != nil { + if err := n.removePermanentAddressLocked(addr); err != nil { return err } } @@ -489,7 +565,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr // n.endpoints is mutex protected so acquire lock. n.mu.RLock() for _, ref := range n.endpoints { - if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { + if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remote ref.ep.HandlePacket(&r, vv) @@ -527,8 +603,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n := r.ref.nic n.mu.RLock() ref, ok := n.endpoints[NetworkEndpointID{dst}] + ok = ok && ref.isValidForOutgoing() && ref.tryIncRef() n.mu.RUnlock() - if ok && ref.tryIncRef() { + if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. ref.ep.HandlePacket(&r, vv) @@ -553,57 +630,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.stack.stats.IP.InvalidAddressesReceived.Increment() } -func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - id := NetworkEndpointID{dst} - - n.mu.RLock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.RUnlock() - return ref - } - - promiscuous := n.promiscuous - // Check if the packet is for a subnet this NIC cares about. - if !promiscuous { - for _, sn := range n.subnets { - if sn.Contains(dst) { - promiscuous = true - break - } - } - } - n.mu.RUnlock() - if promiscuous { - // Try again with the lock in exclusive mode. If we still can't - // get the endpoint, create a new "temporary" one. It will only - // exist while there's a route through it. - n.mu.Lock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.Unlock() - return ref - } - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - n.mu.Unlock() - return nil - } - ref, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: dst, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, CanBePrimaryEndpoint, true) - n.mu.Unlock() - if err == nil { - ref.holdsInsertRef = false - return ref - } - } - - return nil -} - // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { @@ -691,9 +717,33 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +type networkEndpointKind int + +const ( + // A permanent endpoint is created by adding a permanent address (vs. a + // temporary one) to the NIC. Its reference count is biased by 1 to avoid + // removal when no route holds a reference to it. It is removed by explicitly + // removing the permanent address from the NIC. + permanent networkEndpointKind = iota + + // An expired permanent endoint is a permanent endoint that had its address + // removed from the NIC, and it is waiting to be removed once no more routes + // hold a reference to it. This is achieved by decreasing its reference count + // by 1. If its address is re-added before the endpoint is removed, its type + // changes back to permanent and its reference count increases by 1 again. + permanentExpired + + // A temporary endpoint is created for spoofing outgoing packets, or when in + // promiscuous mode and accepting incoming packets that don't match any + // permanent endpoint. Its reference count is not biased by 1 and the + // endpoint is removed immediately when no more route holds a reference to + // it. A temporary endpoint can be promoted to permanent if its address + // is added permanently. + temporary +) + type referencedNetworkEndpoint struct { ilist.Entry - refs int32 ep NetworkEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber @@ -702,11 +752,25 @@ type referencedNetworkEndpoint struct { // protocol. Set to nil otherwise. linkCache LinkAddressCache - // holdsInsertRef is protected by the NIC's mutex. It indicates whether - // the reference count is biased by 1 due to the insertion of the - // endpoint. It is reset to false when RemoveAddress is called on the - // NIC. - holdsInsertRef bool + // refs is counting references held for this endpoint. When refs hits zero it + // triggers the automatic removal of the endpoint from the NIC. + refs int32 + + kind networkEndpointKind +} + +// isValidForOutgoing returns true if the endpoint can be used to send out a +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in spoofing mode. +func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { + return r.kind != permanentExpired || r.nic.spoofing +} + +// isValidForIncoming returns true if the endpoint can accept an incoming +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in promiscuous mode. +func (r *referencedNetworkEndpoint) isValidForIncoming() bool { + return r.kind != permanentExpired || r.nic.promiscuous } // decRef decrements the ref count and cleans up the endpoint once it reaches diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 391ab4344..e52cdd674 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -148,11 +148,15 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // IsResolutionRequired returns true if Resolve() must be called to resolve // the link address before the this route can be written to. func (r *Route) IsResolutionRequired() bool { - return r.ref.linkCache != nil && r.RemoteLinkAddress == "" + return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + if !r.ref.isValidForOutgoing() { + return tcpip.ErrInvalidEndpointState + } + err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() @@ -166,6 +170,10 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { + if !r.ref.isValidForOutgoing() { + return tcpip.ErrInvalidEndpointState + } + if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index d45e547ee..d69162ba1 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -895,7 +895,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } else { for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Match(remoteAddr)) { + if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) { continue } if nic, ok := s.nics[route.NIC]; ok { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 1ab9c575b..4debd1eec 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -181,6 +181,10 @@ func (f *fakeNetworkProtocol) DefaultPrefixLen() int { return fakeDefaultPrefixLen } +func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { + return f.packetCount[int(intfAddr)%len(f.packetCount)] +} + func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } @@ -289,16 +293,75 @@ func TestNetworkReceive(t *testing.T) { } } -func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) { +func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error { r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatal("FindRoute failed:", err) + return err } defer r.Release() + return send(r, payload) +} +func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil { - t.Error("WritePacket failed:", err) + return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) +} + +func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) { + t.Helper() + linkEP.Drain() + if err := sendTo(s, addr, payload); err != nil { + t.Error("sendTo failed:", err) + } + if got, want := linkEP.Drain(), 1; got != want { + t.Errorf("sendTo packet count: got = %d, want %d", got, want) + } +} + +func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) { + t.Helper() + linkEP.Drain() + if err := send(r, payload); err != nil { + t.Error("send failed:", err) + } + if got, want := linkEP.Drain(), 1; got != want { + t.Errorf("send packet count: got = %d, want %d", got, want) + } +} + +func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { + t.Helper() + if gotErr := send(r, payload); gotErr != wantErr { + t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) + } +} + +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { + t.Helper() + if gotErr := sendTo(s, addr, payload); gotErr != wantErr { + t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) + } +} + +func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { + t.Helper() + // testRecvInternal injects one packet, and we expect to receive it. + want := fakeNet.PacketCount(localAddrByte) + 1 + testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) +} + +func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { + t.Helper() + // testRecvInternal injects one packet, and we do NOT expect to receive it. + want := fakeNet.PacketCount(localAddrByte) + testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) +} + +func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) { + t.Helper() + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if got := fakeNet.PacketCount(localAddrByte); got != want { + t.Errorf("receive packet count: got = %d, want %d", got, want) } } @@ -312,17 +375,20 @@ func TestNetworkSend(t *testing.T) { t.Fatal("NewNIC failed:", err) } - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress failed:", err) } // Make sure that the link-layer endpoint received the outbound packet. - sendTo(t, s, "\x03", nil) - if c := linkEP.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x03", linkEP, nil) } func TestNetworkSendMultiRoute(t *testing.T) { @@ -360,24 +426,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { // Set a route table that sends all packets with odd destination // addresses through the first NIC, and all even destination address // through the second one. - s.SetRouteTable([]tcpip.Route{ - {"\x01", "\x01", "\x00", 1}, - {"\x00", "\x01", "\x00", 2}, - }) + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet0, Gateway: "\x00", NIC: 2}, + }) + } // Send a packet to an odd destination. - sendTo(t, s, "\x05", nil) - - if c := linkEP1.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x05", linkEP1, nil) // Send a packet to an even destination. - sendTo(t, s, "\x06", nil) - - if c := linkEP2.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x06", linkEP2, nil) } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { @@ -439,10 +507,20 @@ func TestRoutes(t *testing.T) { // Set a route table that sends all packets with odd destination // addresses through the first NIC, and all even destination address // through the second one. - s.SetRouteTable([]tcpip.Route{ - {"\x01", "\x01", "\x00", 1}, - {"\x00", "\x01", "\x00", 2}, - }) + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet0, Gateway: "\x00", NIC: 2}, + }) + } // Test routes to odd address. testRoute(t, s, 0, "", "\x05", "\x01") @@ -472,6 +550,10 @@ func TestRoutes(t *testing.T) { } func TestAddressRemoval(t *testing.T) { + const localAddrByte byte = 0x01 + localAddr := tcpip.Address([]byte{localAddrByte}) + remoteAddr := tcpip.Address("\x02") + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, linkEP := channel.New(10, defaultMTU, "") @@ -479,99 +561,285 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + // Send and receive packets, and verify they are received. + buf[0] = localAddrByte + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) - // Remove the address, then check that packet doesn't get delivered - // anymore. - if err := s.RemoveAddress(1, "\x01"); err != nil { + // Remove the address, then check that send/receive doesn't work anymore. + if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } } -func TestDelayedRemovalDueToRoute(t *testing.T) { +func TestAddressRemovalWithRouteHeld(t *testing.T) { + const localAddrByte byte = 0x01 + localAddr := tcpip.Address([]byte{localAddrByte}) + remoteAddr := tcpip.Address("\x02") + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { t.Fatal("CreateNIC failed:", err) } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - // Get a route, check that packet is still deliverable. - r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) + r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatal("FindRoute failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 2 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2) - } + // Send and receive packets, and verify they are received. + buf[0] = localAddrByte + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSend(t, r, linkEP, nil) + testSendTo(t, s, remoteAddr, linkEP, nil) - // Remove the address, then check that packet is still deliverable - // because the route is keeping the address alive. - if err := s.RemoveAddress(1, "\x01"); err != nil { + // Remove the address, then check that send/receive doesn't work anymore. + if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } +} - // Release the route, then check that packet is not deliverable anymore. - r.Release() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) +func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { + t.Helper() + info, ok := s.NICInfo()[nicid] + if !ok { + t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + } + if len(addr) == 0 { + // No address given, verify that there is no address assigned to the NIC. + for _, a := range info.ProtocolAddresses { + if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + } + } + return + } + // Address given, verify the address is assigned to the NIC and no other + // address is. + found := false + for _, a := range info.ProtocolAddresses { + if a.Protocol == fakeNetNumber { + if a.AddressWithPrefix.Address == addr { + found = true + } else { + t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr) + } + } + } + if !found { + t.Errorf("verify address: couldn't find %s on the NIC", addr) + } +} + +func TestEndpointExpiration(t *testing.T) { + const ( + localAddrByte byte = 0x01 + remoteAddr tcpip.Address = "\x03" + noAddr tcpip.Address = "" + nicid tcpip.NICID = 1 + ) + localAddr := tcpip.Address([]byte{localAddrByte}) + + for _, promiscuous := range []bool{true, false} { + for _, spoofing := range []bool{true, false} { + t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + buf := buffer.NewView(30) + buf[0] = localAddrByte + + if promiscuous { + if err := s.SetPromiscuousMode(nicid, true); err != nil { + t.Fatal("SetPromiscuousMode failed:", err) + } + } + + if spoofing { + if err := s.SetSpoofing(nicid, true); err != nil { + t.Fatal("SetSpoofing failed:", err) + } + } + + // 1. No Address yet, send should only work for spoofing, receive for + // promiscuous mode. + //----------------------- + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + + // 2. Add Address, everything should work. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + + // 3. Remove the address, send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + + // 4. Add Address back, everything should work again. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + + // 5. Take a reference to the endpoint by getting a route. Verify that + // we can still send/receive, including sending using the route. + //----------------------- + r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, linkEP, nil) + + // 6. Remove the address. Send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + testSend(t, r, linkEP, nil) + testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + + // 7. Add Address back, everything should work again. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, linkEP, nil) + + // 8. Remove the route, sendTo/recv should still work. + //----------------------- + r.Release() + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + + // 9. Remove the address. Send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + }) + } } } @@ -583,9 +851,13 @@ func TestPromiscuousMode(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) @@ -593,22 +865,15 @@ func TestPromiscuousMode(t *testing.T) { // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } + const localAddrByte byte = 0x01 + buf[0] = localAddrByte + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testRecv(t, fakeNet, localAddrByte, linkEP, buf) // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) @@ -621,54 +886,120 @@ func TestPromiscuousMode(t *testing.T) { if err := s.SetPromiscuousMode(1, false); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) +} - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) +func TestSpoofingWithAddress(t *testing.T) { + localAddr := tcpip.Address("\x01") + nonExistentLocalAddr := tcpip.Address("\x02") + dstAddr := tcpip.Address("\x03") + + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } + + // With address spoofing disabled, FindRoute does not permit an address + // that was not added to the NIC to be used as the source. + r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err == nil { + t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) + } + + // With address spoofing enabled, FindRoute permits any address to be used + // as the source. + if err := s.SetSpoofing(1, true); err != nil { + t.Fatal("SetSpoofing failed:", err) + } + r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + if r.LocalAddress != nonExistentLocalAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + } + // Sending a packet works. + testSendTo(t, s, dstAddr, linkEP, nil) + testSend(t, r, linkEP, nil) + + // FindRoute should also work with a local address that exists on the NIC. + r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + if r.LocalAddress != localAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + } + // Sending a packet using the route works. + testSend(t, r, linkEP, nil) } -func TestAddressSpoofing(t *testing.T) { - srcAddr := tcpip.Address("\x01") +func TestSpoofingNoAddress(t *testing.T) { + nonExistentLocalAddr := tcpip.Address("\x01") dstAddr := tcpip.Address("\x02") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") + id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil { - t.Fatal("AddAddress failed:", err) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) - // With address spoofing disabled, FindRoute does not permit an address // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err == nil { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } + // Sending a packet fails. + testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute) // With address spoofing enabled, FindRoute permits any address to be used // as the source. if err := s.SetSpoofing(1, true); err != nil { t.Fatal("SetSpoofing failed:", err) } - r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatal("FindRoute failed:", err) } - if r.LocalAddress != srcAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr) + if r.LocalAddress != nonExistentLocalAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } + // Sending a packet works. + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) } func TestBroadcastNeedsNoRoute(t *testing.T) { @@ -806,16 +1137,20 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - buf[0] = 1 - fakeNet.packetCount[1] = 0 + const localAddrByte byte = 0x01 + buf[0] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -824,9 +1159,52 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { t.Fatal("AddSubnet failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) +} + +// Set the subnet, then check that CheckLocalAddress returns the correct NIC. +func TestCheckLocalAddressForSubnet(t *testing.T) { + const nicID tcpip.NICID = 1 + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}}) + } + + subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) + + if err != nil { + t.Fatal("NewSubnet failed:", err) + } + if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil { + t.Fatal("AddSubnet failed:", err) + } + + // Loop over all subnet addresses and check them. + numOfAddresses := 1 << uint(8-subnet.Prefix()) + if numOfAddresses < 1 || numOfAddresses > 255 { + t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) + } + addr := []byte(subnet.ID()) + for i := 0; i < numOfAddresses; i++ { + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID) + } + addr[0]++ + } + + // Trying the next address should fail since it is outside the subnet range. + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0) } } @@ -839,16 +1217,20 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - buf[0] = 1 - fakeNet.packetCount[1] = 0 + const localAddrByte byte = 0x01 + buf[0] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -856,10 +1238,7 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { t.Fatal("AddSubnet failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) } func TestNetworkOptions(t *testing.T) { @@ -1213,15 +1592,19 @@ func TestNICStats(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id1, linkEP1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id1); err != nil { - t.Fatal("CreateNIC failed:", err) + t.Fatal("CreateNIC failed: ", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress failed:", err) } // Route all packets for address \x01 to NIC 1. - s.SetRouteTable([]tcpip.Route{ - {"\x01", "\xff", "\x00", 1}, - }) + { + subnet, err := tcpip.NewSubnet("\x01", "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } // Send a packet to address 1. buf := buffer.NewView(30) @@ -1236,7 +1619,9 @@ func TestNICStats(t *testing.T) { payload := buffer.NewView(10) // Write a packet out via the address for NIC 1 - sendTo(t, s, "\x01", payload) + if err := sendTo(s, "\x01", payload); err != nil { + t.Fatal("sendTo failed: ", err) + } want := uint64(linkEP1.Drain()) if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) @@ -1270,9 +1655,13 @@ func TestNICForwarding(t *testing.T) { } // Route all packets to address 3 to NIC 2. - s.SetRouteTable([]tcpip.Route{ - {"\x03", "\xff", "\x00", 2}, - }) + { + subnet, err := tcpip.NewSubnet("\x03", "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}}) + } // Send a packet to address 3. buf := buffer.NewView(30) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index eee3144cd..5335897f5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -65,7 +65,7 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { return 0, nil, tcpip.ErrNoRoute } @@ -79,10 +79,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) return 0, nil, err } - return uintptr(len(v)), nil, nil + return int64(len(v)), nil, nil } -func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -105,6 +105,11 @@ func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*fakeTransportEndpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { f.peerAddr = addr.Addr @@ -279,7 +284,13 @@ func TestTransportReceive(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress failed: %v", err) @@ -335,7 +346,13 @@ func TestTransportControlReceive(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress failed: %v", err) @@ -401,7 +418,13 @@ func TestTransportSend(t *testing.T) { t.Fatalf("AddAddress failed: %v", err) } - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } // Create endpoint and bind it. wq := waiter.Queue{} @@ -492,10 +515,20 @@ func TestTransportForwarding(t *testing.T) { // Route all packets to address 3 to NIC 2 and all packets to address // 1 to NIC 1. - s.SetRouteTable([]tcpip.Route{ - {"\x03", "\xff", "\x00", 2}, - {"\x01", "\xff", "\x00", 1}, - }) + { + subnet0, err := tcpip.NewSubnet("\x03", "\xff") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet0, Gateway: "\x00", NIC: 2}, + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + }) + } wq := waiter.Queue{} ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 119712d2f..8f9b86cce 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -31,6 +31,7 @@ package tcpip import ( "errors" "fmt" + "math/bits" "reflect" "strconv" "strings" @@ -145,8 +146,17 @@ type Address string type AddressMask string // String implements Stringer. -func (a AddressMask) String() string { - return Address(a).String() +func (m AddressMask) String() string { + return Address(m).String() +} + +// Prefix returns the number of bits before the first host bit. +func (m AddressMask) Prefix() int { + p := 0 + for _, b := range []byte(m) { + p += bits.LeadingZeros8(^b) + } + return p } // Subnet is a subnet defined by its address and mask. @@ -168,6 +178,11 @@ func NewSubnet(a Address, m AddressMask) (Subnet, error) { return Subnet{a, m}, nil } +// String implements Stringer. +func (s Subnet) String() string { + return fmt.Sprintf("%s/%d", s.ID(), s.Prefix()) +} + // Contains returns true iff the address is of the same length and matches the // subnet address and mask. func (s *Subnet) Contains(a Address) bool { @@ -190,28 +205,13 @@ func (s *Subnet) ID() Address { // Bits returns the number of ones (network bits) and zeros (host bits) in the // subnet mask. func (s *Subnet) Bits() (ones int, zeros int) { - for _, b := range []byte(s.mask) { - for i := uint(0); i < 8; i++ { - if b&(1<<i) == 0 { - zeros++ - } else { - ones++ - } - } - } - return + ones = s.mask.Prefix() + return ones, len(s.mask)*8 - ones } // Prefix returns the number of bits before the first host bit. func (s *Subnet) Prefix() int { - for i, b := range []byte(s.mask) { - for j := 7; j >= 0; j-- { - if b&(1<<uint(j)) == 0 { - return i*8 + 7 - j - } - } - } - return len(s.mask) * 8 + return s.mask.Prefix() } // Mask returns the subnet mask. @@ -329,12 +329,12 @@ type Endpoint interface { // ErrNoLinkAddress and a notification channel is returned for the caller to // block. Channel is closed once address resolution is complete (success or // not). The channel is only non-nil in this case. - Write(Payload, WriteOptions) (uintptr, <-chan struct{}, *Error) + Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error) // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. - Peek([][]byte) (uintptr, ControlMessages, *Error) + Peek([][]byte) (int64, ControlMessages, *Error) // Connect connects the endpoint to its peer. Specifying a NIC is // optional. @@ -353,6 +353,9 @@ type Endpoint interface { // ErrAddressFamilyNotSupported must be returned. Connect(address FullAddress) *Error + // Disconnect disconnects the endpoint from its peer. + Disconnect() *Error + // Shutdown closes the read and/or write end of the endpoint connection // to its peer. Shutdown(flags ShutdownFlags) *Error @@ -567,13 +570,8 @@ type BroadcastOption int // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. type Route struct { - // Destination is the address that must be matched against the masked - // target address to check if this row is viable. - Destination Address - - // Mask specifies which bits of the Destination and the target address - // must match for this row to be viable. - Mask AddressMask + // Destination must contain the target address for this row to be viable. + Destination Subnet // Gateway is the gateway to be used if this row is viable. Gateway Address @@ -582,25 +580,15 @@ type Route struct { NIC NICID } -// Match determines if r is viable for the given destination address. -func (r *Route) Match(addr Address) bool { - if len(addr) != len(r.Destination) { - return false - } - - // Using header.Ipv4Broadcast would introduce an import cycle, so - // we'll use a literal instead. - if addr == "\xff\xff\xff\xff" { - return true - } - - for i := 0; i < len(r.Destination); i++ { - if (addr[i] & r.Mask[i]) != r.Destination[i] { - return false - } +// String implements the fmt.Stringer interface. +func (r Route) String() string { + var out strings.Builder + fmt.Fprintf(&out, "%s", r.Destination) + if len(r.Gateway) > 0 { + fmt.Fprintf(&out, " via %s", r.Gateway) } - - return true + fmt.Fprintf(&out, " nic %d", r.NIC) + return out.String() } // LinkEndpointID represents a data link layer endpoint. @@ -1072,6 +1060,11 @@ type AddressWithPrefix struct { PrefixLen int } +// String implements the fmt.Stringer interface. +func (a AddressWithPrefix) String() string { + return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen) +} + // ProtocolAddress is an address and the network protocol it is associated // with. type ProtocolAddress struct { diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index ebb1c1b56..fb3a0a5ee 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -60,12 +60,12 @@ func TestSubnetBits(t *testing.T) { }{ {"\x00", 0, 8}, {"\x00\x00", 0, 16}, - {"\x36", 4, 4}, - {"\x5c", 4, 4}, - {"\x5c\x5c", 8, 8}, - {"\x5c\x36", 8, 8}, - {"\x36\x5c", 8, 8}, - {"\x36\x36", 8, 8}, + {"\x36", 0, 8}, + {"\x5c", 0, 8}, + {"\x5c\x5c", 0, 16}, + {"\x5c\x36", 0, 16}, + {"\x36\x5c", 0, 16}, + {"\x36\x36", 0, 16}, {"\xff", 8, 0}, {"\xff\xff", 16, 0}, } @@ -122,26 +122,6 @@ func TestSubnetCreation(t *testing.T) { } } -func TestRouteMatch(t *testing.T) { - tests := []struct { - d Address - m AddressMask - a Address - want bool - }{ - {"\xc2\x80", "\xff\xf0", "\xc2\x80", true}, - {"\xc2\x80", "\xff\xf0", "\xc2\x00", false}, - {"\xc2\x00", "\xff\xf0", "\xc2\x00", true}, - {"\xc2\x00", "\xff\xf0", "\xc2\x80", false}, - } - for _, tt := range tests { - r := Route{Destination: tt.d, Mask: tt.m} - if got := r.Match(tt.a); got != tt.want { - t.Errorf("Route(%v).Match(%v) = %v, want %v", r, tt.a, got, tt.want) - } - } -} - func TestAddressString(t *testing.T) { for _, want := range []string{ // Taken from stdlib. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 9a4306011..451d3880e 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -136,34 +136,6 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { return e.stack.IPTables(), nil } -// Resume implements tcpip.ResumableEndpoint.Resume. -func (e *endpoint) Resume(s *stack.Stack) { - e.stack = s - - if e.state != stateBound && e.state != stateConnected { - return - } - - var err *tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) - if err != nil { - panic(*err) - } - - e.id.LocalAddress = e.route.LocalAddress - } else if len(e.id.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) - } - } - - e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) - if err != nil { - panic(*err) - } -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -233,7 +205,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -335,11 +307,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c return 0, nil, err } - return uintptr(len(v)), nil, nil + return int64(len(v)), nil, nil } // Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -456,16 +428,16 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t return netProto, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if addr.Addr == "" { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - nicid := addr.NIC localPort := uint16(0) switch e.state { diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 43551d642..c587b96b6 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,6 +15,7 @@ package icmp import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -64,3 +65,31 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + + if e.state != stateBound && e.state != stateConnected { + return + } + + var err *tcpip.Error + if e.state == stateConnected { + e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) + if err != nil { + panic(err) + } + + e.id.LocalAddress = e.route.LocalAddress + } else if len(e.id.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) + if err != nil { + panic(err) + } +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index eab3dcbd2..13e17e2a6 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -174,31 +174,6 @@ func (ep *endpoint) IPTables() (iptables.IPTables, error) { return ep.stack.IPTables(), nil } -// Resume implements tcpip.ResumableEndpoint.Resume. -func (ep *endpoint) Resume(s *stack.Stack) { - ep.stack = s - - // If the endpoint is connected, re-connect. - if ep.connected { - var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) - if err != nil { - panic(*err) - } - } - - // If the endpoint is bound, re-bind. - if ep.bound { - if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { - panic(tcpip.ErrBadLocalAddress) - } - } - - if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { - panic(*err) - } -} - // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !ep.associated { @@ -232,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes } // Write implements tcpip.Endpoint.Write. -func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -336,7 +311,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { // We may need to resolve the route (match a link layer address to the // network address). If that requires blocking (e.g. to use ARP), // return a channel on which the caller can wait. @@ -366,24 +341,24 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintpt return 0, nil, tcpip.ErrUnknownProtocol } - return uintptr(len(payloadBytes)), nil, nil + return int64(len(payloadBytes)), nil, nil } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect implements tcpip.Endpoint.Connect. func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if addr.Addr == "" { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - if ep.closed { return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 44abddb2b..168953dec 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -15,6 +15,7 @@ package raw import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -64,3 +65,28 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) { func (ep *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(ep) } + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (ep *endpoint) Resume(s *stack.Stack) { + ep.stack = s + + // If the endpoint is connected, re-connect. + if ep.connected { + var err *tcpip.Error + ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) + if err != nil { + panic(err) + } + } + + // If the endpoint is bound, re-bind. + if ep.bound { + if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + panic(err) + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index e67169111..ac927569a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -720,107 +720,6 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { return e.stack.IPTables(), nil } -// Resume implements tcpip.ResumableEndpoint.Resume. -func (e *endpoint) Resume(s *stack.Stack) { - e.stack = s - e.segmentQueue.setLimit(MaxUnprocessedSegments) - e.workMu.Init() - - state := e.state - switch state { - case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss SendBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { - if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) - } - if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) - } - } - } - - bind := func() { - e.state = StateInitial - if len(e.bindAddress) == 0 { - e.bindAddress = e.id.LocalAddress - } - if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) - } - } - - switch state { - case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: - bind() - if len(e.connectingAddress) == 0 { - e.connectingAddress = e.id.RemoteAddress - // This endpoint is accepted by netstack but not yet by - // the app. If the endpoint is IPv6 but the remote - // address is IPv4, we need to connect as IPv6 so that - // dual-stack mode can be properly activated. - if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress - } - } - // Reset the scoreboard to reinitialize the sack information as - // we do not restore SACK information. - e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { - panic("endpoint connecting failed: " + err.String()) - } - connectedLoading.Done() - case StateListen: - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - bind() - backlog := cap(e.acceptedChan) - if err := e.Listen(backlog); err != nil { - panic("endpoint listening failed: " + err.String()) - } - listenLoading.Done() - tcpip.AsyncLoading.Done() - }() - case StateConnecting, StateSynSent, StateSynRecv: - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { - panic("endpoint connecting failed: " + err.String()) - } - connectingLoading.Done() - tcpip.AsyncLoading.Done() - }() - case StateBound: - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - tcpip.AsyncLoading.Done() - }() - case StateClose: - if e.isPortReserved { - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - e.state = StateClose - tcpip.AsyncLoading.Done() - }() - } - fallthrough - case StateError: - tcpip.DeleteDanglingEndpoint(e) - } -} - // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() @@ -878,60 +777,95 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { return v, nil } -// Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { - // Linux completely ignores any address passed to sendto(2) for TCP sockets - // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More - // and opts.EndOfRecord are also ignored. - - e.mu.RLock() - defer e.mu.RUnlock() - +// isEndpointWritableLocked checks if a given endpoint is writable +// and also returns the number of bytes that can be written at this +// moment. If the endpoint is not writable then it returns an error +// indicating the reason why it's not writable. +// Caller must hold e.mu and e.sndBufMu +func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { // The endpoint cannot be written to if it's not connected. if !e.state.connected() { switch e.state { case StateError: - return 0, nil, e.hardError + return 0, e.hardError default: - return 0, nil, tcpip.ErrClosedForSend + return 0, tcpip.ErrClosedForSend } } - // Nothing to do if the buffer is empty. - if p.Size() == 0 { - return 0, nil, nil - } - - e.sndBufMu.Lock() - // Check if the connection has already been closed for sends. if e.sndClosed { - e.sndBufMu.Unlock() - return 0, nil, tcpip.ErrClosedForSend + return 0, tcpip.ErrClosedForSend } - // Check against the limit. avail := e.sndBufSize - e.sndBufUsed if avail <= 0 { + return 0, tcpip.ErrWouldBlock + } + return avail, nil +} + +// Write writes data to the endpoint's peer. +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // Linux completely ignores any address passed to sendto(2) for TCP sockets + // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More + // and opts.EndOfRecord are also ignored. + + e.mu.RLock() + e.sndBufMu.Lock() + + avail, err := e.isEndpointWritableLocked() + if err != nil { e.sndBufMu.Unlock() - return 0, nil, tcpip.ErrWouldBlock + e.mu.RUnlock() + return 0, nil, err } + e.sndBufMu.Unlock() + e.mu.RUnlock() + + // Nothing to do if the buffer is empty. + if p.Size() == 0 { + return 0, nil, nil + } + + // Copy in memory without holding sndBufMu so that worker goroutine can + // make progress independent of this operation. v, perr := p.Get(avail) if perr != nil { - e.sndBufMu.Unlock() return 0, nil, perr } - l := len(v) - s := newSegmentFromView(&e.route, e.id, v) + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a + // write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due to a + // simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } // Add data to the send queue. + l := len(v) + s := newSegmentFromView(&e.route, e.id, v) e.sndBufUsed += l e.sndBufInQueue += seqnum.Size(l) e.sndQueue.PushBack(s) e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() if e.workMu.TryLock() { // Do the work inline. @@ -941,13 +875,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c // Let the protocol goroutine do the work. e.sndWaker.Assert() } - return uintptr(l), nil, nil + return int64(l), nil, nil } // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. -func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() @@ -973,8 +907,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er // Make a copy of vec so we can modify the slide headers. vec = append([][]byte(nil), vec...) - var num uintptr - + var num int64 for s := e.rcvList.Front(); s != nil; s = s.Next() { views := s.data.Views() @@ -993,7 +926,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er n := copy(vec[0], v) v = v[n:] vec[0] = vec[0][n:] - num += uintptr(n) + num += int64(n) } } } @@ -1415,7 +1348,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol netProto = header.IPv4ProtocolNumber addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == "\x00\x00\x00\x00" { + if addr.Addr == header.IPv4Any { addr.Addr = "" } } @@ -1429,13 +1362,13 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol return netProto, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - if addr.Addr == "" && addr.Port == 0 { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - return e.connect(addr, true, true) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index ef88dc618..831389ec7 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -20,6 +20,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -167,6 +168,107 @@ func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.workMu.Init() + + state := e.state + switch state { + case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: + var ss SendBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { + panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) + } + if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { + panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) + } + } + } + + bind := func() { + e.state = StateInitial + if len(e.bindAddress) == 0 { + e.bindAddress = e.id.LocalAddress + } + if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { + panic("endpoint binding failed: " + err.String()) + } + } + + switch state { + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + bind() + if len(e.connectingAddress) == 0 { + e.connectingAddress = e.id.RemoteAddress + // This endpoint is accepted by netstack but not yet by + // the app. If the endpoint is IPv6 but the remote + // address is IPv4, we need to connect as IPv6 so that + // dual-stack mode can be properly activated. + if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress + } + } + // Reset the scoreboard to reinitialize the sack information as + // we do not restore SACK information. + e.scoreboard.Reset() + if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + connectedLoading.Done() + case StateListen: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + bind() + backlog := cap(e.acceptedChan) + if err := e.Listen(backlog); err != nil { + panic("endpoint listening failed: " + err.String()) + } + listenLoading.Done() + tcpip.AsyncLoading.Done() + }() + case StateConnecting, StateSynSent, StateSynRecv: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + bind() + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + connectingLoading.Done() + tcpip.AsyncLoading.Done() + }() + case StateBound: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + connectingLoading.Wait() + bind() + tcpip.AsyncLoading.Done() + }() + case StateClose: + if e.isPortReserved { + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + connectingLoading.Wait() + bind() + e.state = StateClose + tcpip.AsyncLoading.Done() + }() + } + fallthrough + case StateError: + tcpip.DeleteDanglingEndpoint(e) + } +} + // saveLastError is invoked by stateify. func (e *endpoint) saveLastError() string { if e.lastError == nil { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 915a98047..f79b8ec5f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2874,15 +2874,11 @@ func makeStack() (*stack.Stack, *tcpip.Error) { s.SetRouteTable([]tcpip.Route{ { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }, { - Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv6EmptySubnet, NIC: 1, }, }) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index bcc0f3e28..272481aa0 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -168,15 +168,11 @@ func New(t *testing.T, mtu uint32) *Context { s.SetRouteTable([]tcpip.Route{ { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }, { - Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv6EmptySubnet, NIC: 1, }, }) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 7c12a6092..ac5905772 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -178,53 +178,6 @@ func (e *endpoint) IPTables() (iptables.IPTables, error) { return e.stack.IPTables(), nil } -// Resume implements tcpip.ResumableEndpoint.Resume. -func (e *endpoint) Resume(s *stack.Stack) { - e.stack = s - - for _, m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { - panic(err) - } - } - - if e.state != stateBound && e.state != stateConnected { - return - } - - netProto := e.effectiveNetProtos[0] - // Connect() and bindLocked() both assert - // - // netProto == header.IPv6ProtocolNumber - // - // before creating a multi-entry effectiveNetProtos. - if len(e.effectiveNetProtos) > 1 { - netProto = header.IPv6ProtocolNumber - } - - var err *tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) - if err != nil { - panic(*err) - } - } else if len(e.id.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) - } - } - - // Our saved state had a port, but we don't actually have a - // reservation. We need to remove the port from our state, but still - // pass it to the reservation machinery. - id := e.id - e.id.LocalPort = 0 - e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) - if err != nil { - panic(*err) - } -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -296,6 +249,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // specified address is a multicast address. func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.id.LocalAddress + if isBroadcastOrMulticast(localAddr) { + // A packet can only originate from a unicast address (i.e., an interface). + localAddr = "" + } + if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { if nicid == 0 { nicid = e.multicastNICID @@ -315,7 +273,7 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -421,11 +379,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil { return 0, nil, err } - return uintptr(len(v)), nil, nil + return int64(len(v)), nil, nil } // Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -495,7 +453,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } nicID := v.NIC - if v.InterfaceAddr == header.IPv4Any { + + // The interface address is considered not-set if it is empty or contains + // all-zeros. The former represent the zero-value in golang, the latter the + // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall. + allZeros := header.IPv4Any + if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros { if nicID == 0 { r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) if err == nil { @@ -739,7 +702,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t netProto = header.IPv4ProtocolNumber addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == "\x00\x00\x00\x00" { + if addr.Addr == header.IPv4Any { addr.Addr = "" } @@ -758,7 +721,8 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t return netProto, nil } -func (e *endpoint) disconnect() *tcpip.Error { +// Disconnect implements tcpip.Endpoint.Disconnect. +func (e *endpoint) Disconnect() *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -797,9 +761,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - if addr.Addr == "" { - return e.disconnect() - } if addr.Port == 0 { // We don't support connecting to port zero. return tcpip.ErrInvalidEndpointState @@ -963,8 +924,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } nicid := addr.NIC - if len(addr.Addr) != 0 { - // A local address was specified, verify that it's valid. + if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { + // A local unicast address was specified, verify that it's valid. nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicid == 0 { return tcpip.ErrBadLocalAddress @@ -1113,3 +1074,7 @@ func (e *endpoint) State() uint32 { // TODO(b/112063468): Translate internal state to values returned by Linux. return 0 } + +func isBroadcastOrMulticast(a tcpip.Address) bool { + return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 86db36260..5cbb56120 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -15,7 +15,9 @@ package udp import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -64,3 +66,51 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + + for _, m := range e.multicastMemberships { + if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + panic(err) + } + } + + if e.state != stateBound && e.state != stateConnected { + return + } + + netProto := e.effectiveNetProtos[0] + // Connect() and bindLocked() both assert + // + // netProto == header.IPv6ProtocolNumber + // + // before creating a multi-entry effectiveNetProtos. + if len(e.effectiveNetProtos) > 1 { + netProto = header.IPv6ProtocolNumber + } + + var err *tcpip.Error + if e.state == stateConnected { + e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) + if err != nil { + panic(err) + } + } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound + // A local unicast address is specified, verify that it's valid. + if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + // Our saved state had a port, but we don't actually have a + // reservation. We need to remove the port from our state, but still + // pass it to the reservation machinery. + id := e.id + e.id.LocalPort = 0 + e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + if err != nil { + panic(err) + } +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 56c285f88..9da6edce2 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,6 +16,7 @@ package udp_test import ( "bytes" + "fmt" "math" "math/rand" "testing" @@ -34,13 +35,19 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// Addresses and ports used for testing. It is recommended that tests stick to +// using these addresses as it allows using the testFlow helper. +// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*' +// represents the remote endpoint. const ( + v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr - testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr - multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + stackV4MappedAddr = v4MappedAddrPrefix + stackAddr + testV4MappedAddr = v4MappedAddrPrefix + testAddr + multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr + broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr + v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00" stackAddr = "\x0a\x00\x00\x01" stackPort = 1234 @@ -48,7 +55,7 @@ const ( testPort = 4096 multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - multicastPort = 1234 + broadcastAddr = header.IPv4Broadcast // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -56,6 +63,205 @@ const ( defaultMTU = 65536 ) +// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in +// a packet header. These values are used to populate a header or verify one. +// Note that because they are used in packet headers, the addresses are never in +// a V4-mapped format. +type header4Tuple struct { + srcAddr tcpip.FullAddress + dstAddr tcpip.FullAddress +} + +// testFlow implements a helper type used for sending and receiving test +// packets. A given test flow value defines 1) the socket endpoint used for the +// test and 2) the type of packet send or received on the endpoint. E.g., a +// multicastV6Only flow is a V6 multicast packet passing through a V6-only +// endpoint. The type provides helper methods to characterize the flow (e.g., +// isV4) as well as return a proper header4Tuple for it. +type testFlow int + +const ( + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket +) + +func (flow testFlow) String() string { + switch flow { + case unicastV4: + return "unicastV4" + case unicastV6: + return "unicastV6" + case unicastV6Only: + return "unicastV6Only" + case unicastV4in6: + return "unicastV4in6" + case multicastV4: + return "multicastV4" + case multicastV6: + return "multicastV6" + case multicastV6Only: + return "multicastV6Only" + case multicastV4in6: + return "multicastV4in6" + case broadcast: + return "broadcast" + case broadcastIn6: + return "broadcastIn6" + default: + return "unknown" + } +} + +// packetDirection explains if a flow is incoming (read) or outgoing (write). +type packetDirection int + +const ( + incoming packetDirection = iota + outgoing +) + +// header4Tuple returns the header4Tuple for the given flow and direction. Note +// that the tuple contains no mapped addresses as those only exist at the socket +// level but not at the packet header level. +func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { + var h header4Tuple + if flow.isV4() { + if d == outgoing { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, + dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, + } + } else { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, + dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, + } + } + if flow.isMulticast() { + h.dstAddr.Addr = multicastAddr + } else if flow.isBroadcast() { + h.dstAddr.Addr = broadcastAddr + } + } else { // IPv6 + if d == outgoing { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, + dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + } + } else { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, + } + } + if flow.isMulticast() { + h.dstAddr.Addr = multicastV6Addr + } + } + return h +} + +func (flow testFlow) getMcastAddr() tcpip.Address { + if flow.isV4() { + return multicastAddr + } + return multicastV6Addr +} + +// mapAddrIfApplicable converts the given V4 address into its V4-mapped version +// if it is applicable to the flow. +func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { + if flow.isMapped() { + return v4MappedAddrPrefix + v4Addr + } + return v4Addr +} + +// netProto returns the protocol number used for the network packet. +func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { + if flow.isV4() { + return ipv4.ProtocolNumber + } + return ipv6.ProtocolNumber +} + +// sockProto returns the protocol number used when creating the socket +// endpoint for this flow. +func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { + switch flow { + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + return ipv6.ProtocolNumber + case unicastV4, multicastV4, broadcast: + return ipv4.ProtocolNumber + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { + if flow.isV4() { + return checker.IPv4 + } + return checker.IPv6 +} + +func (flow testFlow) isV6() bool { return !flow.isV4() } +func (flow testFlow) isV4() bool { + return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() +} + +func (flow testFlow) isV6Only() bool { + switch flow { + case unicastV6Only, multicastV6Only: + return true + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isMulticast() bool { + switch flow { + case multicastV4, multicastV4in6, multicastV6, multicastV6Only: + return true + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isBroadcast() bool { + switch flow { + case broadcast, broadcastIn6: + return true + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isMapped() bool { + switch flow { + case unicastV4in6, multicastV4in6, broadcastIn6: + return true + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -65,12 +271,9 @@ type testContext struct { wq waiter.Queue } -type headers struct { - srcPort uint16 - dstPort uint16 -} - func newDualTestContext(t *testing.T, mtu uint32) *testContext { + t.Helper() + s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) id, linkEP := channel.New(256, mtu, "") @@ -91,15 +294,11 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { s.SetRouteTable([]tcpip.Route{ { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }, { - Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv6EmptySubnet, NIC: 1, }, }) @@ -117,51 +316,54 @@ func (c *testContext) cleanup() { } } -func (c *testContext) createV6Endpoint(v6only bool) { +func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { + c.t.Helper() + var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) + c.t.Fatal("NewEndpoint failed: ", err) } +} - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.ep.SetSockOpt(v); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) +func (c *testContext) createEndpointForFlow(flow testFlow) { + c.t.Helper() + + c.createEndpoint(flow.sockProto()) + if flow.isV6Only() { + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + } else if flow.isBroadcast() { + if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { + c.t.Fatal("SetSockOpt failed:", err) + } } } -func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte { +// getPacketAndVerify reads a packet from the link endpoint and verifies the +// header against expected values from the given test flow. In addition, it +// calls any extra checker functions provided. +func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { + c.t.Helper() + select { case p := <-c.linkEP.C: - if p.Proto != protocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber) + if p.Proto != flow.netProto() { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } b := make([]byte, len(p.Header)+len(p.Payload)) copy(b, p.Header) copy(b[len(p.Header):], p.Payload) - var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker) - var srcAddr, dstAddr tcpip.Address - switch protocolNumber { - case ipv4.ProtocolNumber: - checkerFn = checker.IPv4 - srcAddr, dstAddr = stackAddr, testAddr - if multicast { - dstAddr = multicastAddr - } - case ipv6.ProtocolNumber: - checkerFn = checker.IPv6 - srcAddr, dstAddr = stackV6Addr, testV6Addr - if multicast { - dstAddr = multicastV6Addr - } - default: - c.t.Fatalf("unknown protocol %d", protocolNumber) - } - checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr)) + h := flow.header4Tuple(outgoing) + checkers := append( + checkers, + checker.SrcAddr(h.srcAddr.Addr), + checker.DstAddr(h.dstAddr.Addr), + checker.UDP(checker.DstPort(h.dstAddr.Port)), + ) + flow.checkerFn()(c.t, b, checkers...) return b case <-time.After(2 * time.Second): @@ -171,7 +373,22 @@ func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, mult return nil } -func (c *testContext) sendV6Packet(payload []byte, h *headers) { +// injectPacket creates a packet of the given flow and with the given payload, +// and injects it into the link endpoint. +func (c *testContext) injectPacket(flow testFlow, payload []byte) { + c.t.Helper() + + h := flow.header4Tuple(incoming) + if flow.isV4() { + c.injectV4Packet(payload, &h) + } else { + c.injectV6Packet(payload, &h) + } +} + +// injectV6Packet creates a V6 test packet with the given payload and header +// values, and injects it into the link endpoint. +func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -182,20 +399,20 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) { PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -205,7 +422,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) { c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) } -func (c *testContext) sendPacket(payload []byte, h *headers) { +// injectV6Packet creates a V4 test packet with the given payload and header +// values, and injects it into the link endpoint. +func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -217,21 +436,21 @@ func (c *testContext) sendPacket(payload []byte, h *headers) { TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), - SrcAddr: testAddr, - DstAddr: stackAddr, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) ip.SetChecksum(^ip.CalculateChecksum()) // Initialize the UDP header. u := header.UDP(buf[header.IPv4MinimumSize:]) u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testAddr, stackAddr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -253,7 +472,7 @@ func TestBindPortReuse(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) var eps [5]tcpip.Endpoint reusePortOpt := tcpip.ReusePortOption(1) @@ -296,9 +515,9 @@ func TestBindPortReuse(t *testing.T) { // Send a packet. port := uint16(i % nports) payload := newPayload() - c.sendV6Packet(payload, &headers{ - srcPort: testPort + port, - dstPort: stackPort, + c.injectV6Packet(payload, &header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port}, + dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, }) var addr tcpip.FullAddress @@ -333,13 +552,14 @@ func TestBindPortReuse(t *testing.T) { } } -func testV4Read(c *testContext) { - // Send a packet. +// testRead sends a packet of the given test flow into the stack by injecting it +// into the link endpoint. It then reads it from the UDP endpoint and verifies +// its correctness. +func testRead(c *testContext, flow testFlow) { + c.t.Helper() + payload := newPayload() - c.sendPacket(payload, &headers{ - srcPort: testPort, - dstPort: stackPort, - }) + c.injectPacket(flow, payload) // Try to receive the data. we, ch := waiter.NewChannelEntry(nil) @@ -363,8 +583,9 @@ func testV4Read(c *testContext) { } // Check the peer address. - if addr.Addr != testAddr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) + h := flow.header4Tuple(incoming) + if addr.Addr != h.srcAddr.Addr { + c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr) } // Check the payload. @@ -377,7 +598,7 @@ func TestBindEphemeralPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { t.Fatalf("ep.Bind(...) failed: %v", err) @@ -388,7 +609,7 @@ func TestBindReservedPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) @@ -447,7 +668,7 @@ func TestV4ReadOnV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -455,29 +676,29 @@ func TestV4ReadOnV6(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to v4 mapped wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}); err != nil { + if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV4ReadOnBoundToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to local address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { @@ -485,69 +706,29 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV6ReadOnV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV6) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - // Send a packet. - payload := newPayload() - c.sendV6Packet(payload, &headers{ - srcPort: testPort, - dstPort: stackPort, - }) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.EventIn) - defer c.wq.EventUnregister(&we) - - var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) - if err == tcpip.ErrWouldBlock { - // Wait for data to become available. - select { - case <-ch: - v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") - } - } - - // Check the peer address. - if addr.Addr != testV6Addr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) - } - - // Check the payload. - if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) - } + // Test acceptance. + testRead(c, unicastV6) } func TestV4ReadOnV4(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - // Create v4 UDP endpoint. - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } + c.createEndpointForFlow(unicastV4) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -555,62 +736,123 @@ func TestV4ReadOnV4(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4) } -func testV4Write(c *testContext) uint16 { - // Write to V4 mapped address. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != nil { - c.t.Fatalf("Write failed: %v", err) +// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast +// address and receive data sent to that address. +func TestReadOnBoundToMulticast(t *testing.T) { + // FIXME(b/128189410): multicastV4in6 currently doesn't work as + // AddMembershipOption doesn't handle V4in6 addresses. + for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to multicast address. + mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr()) + if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + // Join multicast group. + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatal("SetSockOpt failed:", err) + } + + testRead(c, flow) + }) } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) +} + +// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast +// address and receive broadcast data on it. +func TestV4ReadOnBoundToBroadcast(t *testing.T) { + for _, flow := range []testFlow{broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to broadcast address. + bcastAddr := flow.mapAddrIfApplicable(broadcastAddr) + if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + // Test acceptance. + testRead(c, flow) + }) } +} - // Check that we received the packet. - b := c.getPacket(ipv4.ProtocolNumber, false) - udp := header.UDP(header.IPv4(b).Payload()) - checker.IPv4(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) +// testFailingWrite sends a packet of the given test flow into the UDP endpoint +// and verifies it fails with the provided error code. +func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { + c.t.Helper() - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + h := flow.header4Tuple(outgoing) + writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) + + payload := buffer.View(newPayload()) + _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, + }) + if gotErr != wantErr { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) } +} - return udp.SourcePort() +// testWrite sends a packet of the given test flow from the UDP endpoint to the +// flow's destination address:port. It then receives it from the link endpoint +// and verifies its correctness including any additional checker functions +// provided. +func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + return testWriteInternal(c, flow, true, checkers...) } -func testV6Write(c *testContext) uint16 { - // Write to v6 address. +// testWriteWithoutDestination sends a packet of the given test flow from the +// UDP endpoint without giving a destination address:port. It then receives it +// from the link endpoint and verifies its correctness including any additional +// checker functions provided. +func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + return testWriteInternal(c, flow, false, checkers...) +} + +func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + + writeOpts := tcpip.WriteOptions{} + if setDest { + h := flow.header4Tuple(outgoing) + writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) + writeOpts = tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, + } + } payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { c.t.Fatalf("Write failed: %v", err) } - if n != uintptr(len(payload)) { + if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } - // Check that we received the packet. - b := c.getPacket(ipv6.ProtocolNumber, false) - udp := header.UDP(header.IPv6(b).Payload()) - checker.IPv6(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. + // Received the packet and check the payload. + b := c.getPacketAndVerify(flow, checkers...) + var udp header.UDP + if flow.isV4() { + udp = header.UDP(header.IPv4(b).Payload()) + } else { + udp = header.UDP(header.IPv6(b).Payload()) + } if !bytes.Equal(payload, udp.Payload()) { c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) } @@ -619,8 +861,10 @@ func testV6Write(c *testContext) uint16 { } func testDualWrite(c *testContext) uint16 { - v4Port := testV4Write(c) - v6Port := testV6Write(c) + c.t.Helper() + + v4Port := testWrite(c, unicastV4in6) + v6Port := testWrite(c, unicastV6) if v4Port != v6Port { c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) } @@ -632,7 +876,7 @@ func TestDualWriteUnbound(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) testDualWrite(c) } @@ -641,7 +885,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -658,69 +902,51 @@ func TestDualWriteConnectedToV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV6Write(c) + testWrite(c, unicastV6) // Write to V4 mapped address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != tcpip.ErrNetworkUnreachable { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable) - } + testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) } func TestDualWriteConnectedToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV4Write(c) + testWrite(c, unicastV4in6) // Write to v6 address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != tcpip.ErrInvalidEndpointState { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) - } + testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) } func TestV4WriteOnV6Only(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(true) + c.createEndpointForFlow(unicastV6Only) // Write to V4 mapped address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != tcpip.ErrNoRoute { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute) - } + testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute) } func TestV6WriteOnBoundToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Bind to v4 mapped address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { @@ -728,84 +954,154 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { } // Write to v6 address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != tcpip.ErrInvalidEndpointState { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) - } + testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) } func TestV6WriteOnConnected(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) } - // Write without destination. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - - // Check that we received the packet. - b := c.getPacket(ipv6.ProtocolNumber, false) - udp := header.UDP(header.IPv6(b).Payload()) - checker.IPv6(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } + testWriteWithoutDestination(c, unicastV6) } func TestV4WriteOnConnected(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) } - // Write without destination. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) + testWriteWithoutDestination(c, unicastV4) +} + +// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket +// that is bound to a V4 multicast address. +func TestWriteOnBoundToV4Multicast(t *testing.T) { + for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + testWrite(c, flow) + }) } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) +} + +// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a +// socket that is bound to a V4-mapped multicast address. +func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { + for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4Mapped mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) } +} - // Check that we received the packet. - b := c.getPacket(ipv4.ProtocolNumber, false) - udp := header.UDP(header.IPv4(b).Payload()) - checker.IPv4(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) +// TestWriteOnBoundToV6Multicast checks that we can send packets out of a +// socket that is bound to a V6 multicast address. +func TestWriteOnBoundToV6Multicast(t *testing.T) { + for _, flow := range []testFlow{unicastV6, multicastV6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + c.createEndpointForFlow(flow) + + // Bind to V6 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToV6Multicast checks that we can send packets out of a +// V6-only socket that is bound to a V6 multicast address. +func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { + for _, flow := range []testFlow{unicastV6Only, multicastV6Only} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V6 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToBroadcast checks that we can send packets out of a +// socket that is bound to the broadcast address. +func TestWriteOnBoundToBroadcast(t *testing.T) { + for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4 broadcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a +// socket that is bound to the V4-mapped broadcast address. +func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) { + for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4Mapped mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) } } @@ -814,18 +1110,14 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { defer c.cleanup() // Create IPv4 UDP endpoint - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } + c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV4Read(c) + testRead(c, unicastV4) var want uint64 = 1 if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { @@ -837,7 +1129,7 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) testDualWrite(c) @@ -847,244 +1139,102 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { } } -func setSockOptVariants(t *testing.T, optFunc func(*testing.T, string, tcpip.NetworkProtocolNumber, string)) { - for _, name := range []string{"v4", "v6", "dual"} { - t.Run(name, func(t *testing.T) { - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch name { - case "v4": - networkProtocolNumber = ipv4.ProtocolNumber - case "v6", "dual": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatal("unknown test variant") - } +func TestTTL(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - var variants []string - switch name { - case "v4": - variants = []string{"v4"} - case "v6": - variants = []string{"v6"} - case "dual": - variants = []string{"v6", "mapped"} - } + c.createEndpointForFlow(flow) - for _, variant := range variants { - t.Run(variant, func(t *testing.T) { - optFunc(t, name, networkProtocolNumber, variant) - }) + const multicastTTL = 42 + if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) } - }) - } -} -func TestTTL(t *testing.T) { - payload := tcpip.SlicePayload(buffer.View(newPayload())) - - setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) { - for _, typ := range []string{"unicast", "multicast"} { - t.Run(typ, func(t *testing.T) { - var addr tcpip.Address - var port uint16 - switch typ { - case "unicast": - port = testPort - switch variant { - case "v4": - addr = testAddr - case "mapped": - addr = testV4MappedAddr - case "v6": - addr = testV6Addr - default: - t.Fatal("unknown test variant") - } - case "multicast": - port = multicastPort - switch variant { - case "v4": - addr = multicastAddr - case "mapped": - addr = multicastV4MappedAddr - case "v6": - addr = multicastV6Addr - default: - t.Fatal("unknown test variant") - } - default: - t.Fatal("unknown test variant") + var wantTTL uint8 + if flow.isMulticast() { + wantTTL = multicastTTL + } else { + var p stack.NetworkProtocol + if flow.isV4() { + p = ipv4.NewProtocol() + } else { + p = ipv6.NewProtocol() } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) + ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - switch name { - case "v4": - case "v6": - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - case "dual": - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - default: - t.Fatal("unknown test variant") - } - - const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + t.Fatal(err) } + wantTTL = ep.DefaultTTL() + ep.Close() + } - n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != uintptr(len(payload)) { - c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload)) - } + testWrite(c, flow, checker.TTL(wantTTL)) + }) + } +} - checkerFn := checker.IPv4 - switch variant { - case "v4", "mapped": - case "v6": - checkerFn = checker.IPv6 - default: - t.Fatal("unknown test variant") - } - var wantTTL uint8 - var multicast bool - switch typ { - case "unicast": - multicast = false - switch variant { - case "v4", "mapped": - ep, err := ipv4.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - case "v6": - ep, err := ipv6.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - default: - t.Fatal("unknown test variant") - } - case "multicast": - wantTTL = multicastTTL - multicast = true - default: - t.Fatal("unknown test variant") - } +func TestMulticastInterfaceOption(t *testing.T) { + for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + for _, bindTyp := range []string{"bound", "unbound"} { + t.Run(bindTyp, func(t *testing.T) { + for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { + t.Run(optTyp, func(t *testing.T) { + h := flow.header4Tuple(outgoing) + mcastAddr := h.dstAddr.Addr + localIfAddr := h.srcAddr.Addr + + var ifoptSet tcpip.MulticastInterfaceOption + switch optTyp { + case "use local-addr": + ifoptSet.InterfaceAddr = localIfAddr + case "use NICID": + ifoptSet.NIC = 1 + case "use local-addr and NIC": + ifoptSet.InterfaceAddr = localIfAddr + ifoptSet.NIC = 1 + default: + t.Fatal("unknown test variant") + } - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch variant { - case "v4", "mapped": - networkProtocolNumber = ipv4.ProtocolNumber - case "v6": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatal("unknown test variant") - } + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(flow.sockProto()) + + if bindTyp == "bound" { + // Bind the socket by connecting to the multicast address. + // This may have an influence on how the multicast interface + // is set. + addr := tcpip.FullAddress{ + Addr: flow.mapAddrIfApplicable(mcastAddr), + Port: stackPort, + } + if err := c.ep.Connect(addr); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + } - b := c.getPacket(networkProtocolNumber, multicast) - checkerFn(c.t, b, - checker.TTL(wantTTL), - checker.UDP( - checker.DstPort(port), - ), - ) - }) - } - }) -} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } -func TestMulticastInterfaceOption(t *testing.T) { - setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) { - for _, bindTyp := range []string{"bound", "unbound"} { - t.Run(bindTyp, func(t *testing.T) { - for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { - t.Run(optTyp, func(t *testing.T) { - var mcastAddr, localIfAddr tcpip.Address - switch variant { - case "v4": - mcastAddr = multicastAddr - localIfAddr = stackAddr - case "mapped": - mcastAddr = multicastV4MappedAddr - localIfAddr = stackAddr - case "v6": - mcastAddr = multicastV6Addr - localIfAddr = stackV6Addr - default: - t.Fatal("unknown test variant") - } - - var ifoptSet tcpip.MulticastInterfaceOption - switch optTyp { - case "use local-addr": - ifoptSet.InterfaceAddr = localIfAddr - case "use NICID": - ifoptSet.NIC = 1 - case "use local-addr and NIC": - ifoptSet.InterfaceAddr = localIfAddr - ifoptSet.NIC = 1 - default: - t.Fatal("unknown test variant") - } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if bindTyp == "bound" { - // Bind the socket by connecting to the multicast address. - // This may have an influence on how the multicast interface - // is set. - addr := tcpip.FullAddress{ - Addr: mcastAddr, - Port: multicastPort, + // Verify multicast interface addr and NIC were set correctly. + // Note that NIC must be 1 since this is our outgoing interface. + ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} + var ifoptGot tcpip.MulticastInterfaceOption + if err := c.ep.GetSockOpt(&ifoptGot); err != nil { + c.t.Fatalf("GetSockOpt failed: %v", err) } - if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %v", err) + if ifoptGot != ifoptWant { + c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) } - } - - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - // Verify multicast interface addr and NIC were set correctly. - // Note that NIC must be 1 since this is our outgoing interface. - ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} - var ifoptGot tcpip.MulticastInterfaceOption - if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %v", err) - } - if ifoptGot != ifoptWant { - c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) - } - }) - } - }) - } - }) + }) + } + }) + } + }) + } } |