diff options
Diffstat (limited to 'pkg/tcpip')
60 files changed, 2698 insertions, 1483 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 047f8329a..df37c7d5a 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -12,6 +12,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip/buffer", + "//pkg/tcpip/iptables", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index c40924852..0d2637ee4 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -24,6 +24,7 @@ go_test( embed = [":gonet"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/header", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 308f620e5..cd6ce930a 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -404,7 +404,7 @@ func (c *Conn) Write(b []byte) (int, error) { } } - var n uintptr + var n int64 var resCh <-chan struct{} n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) nbytes += int(n) @@ -556,32 +556,50 @@ type PacketConn struct { wq *waiter.Queue } -// NewPacketConn creates a new PacketConn. -func NewPacketConn(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { - // Create UDP endpoint and bind it. +// DialUDP creates a new PacketConn. +// +// If laddr is nil, a local address is automatically chosen. +// +// If raddr is nil, the PacketConn is left unconnected. +func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) if err != nil { return nil, errors.New(err.String()) } - if err := ep.Bind(addr); err != nil { - ep.Close() - return nil, &net.OpError{ - Op: "bind", - Net: "udp", - Addr: fullToUDPAddr(addr), - Err: errors.New(err.String()), + if laddr != nil { + if err := ep.Bind(*laddr); err != nil { + ep.Close() + return nil, &net.OpError{ + Op: "bind", + Net: "udp", + Addr: fullToUDPAddr(*laddr), + Err: errors.New(err.String()), + } } } - c := &PacketConn{ + c := PacketConn{ stack: s, ep: ep, wq: &wq, } c.deadlineTimer.init() - return c, nil + + if raddr != nil { + if err := c.ep.Connect(*raddr); err != nil { + c.ep.Close() + return nil, &net.OpError{ + Op: "connect", + Net: "udp", + Addr: fullToUDPAddr(*raddr), + Err: errors.New(err.String()), + } + } + } + + return &c, nil } func (c *PacketConn) newOpError(op string, err error) *net.OpError { diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 39efe44c7..672f026b2 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -26,6 +26,7 @@ import ( "golang.org/x/net/nettest" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -69,17 +70,13 @@ func newLoopbackStack() (*stack.Stack, *tcpip.Error) { s.SetRouteTable([]tcpip.Route{ // IPv4 { - Destination: tcpip.Address(strings.Repeat("\x00", 4)), - Mask: tcpip.AddressMask(strings.Repeat("\x00", 4)), - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: NICID, }, // IPv6 { - Destination: tcpip.Address(strings.Repeat("\x00", 16)), - Mask: tcpip.AddressMask(strings.Repeat("\x00", 16)), - Gateway: "", + Destination: header.IPv6EmptySubnet, NIC: NICID, }, }) @@ -371,9 +368,9 @@ func TestUDPForwarder(t *testing.T) { }) s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket) - c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber) + c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) if err != nil { - t.Fatal("NewPacketConn(port 5):", err) + t.Fatal("DialUDP(bind port 5):", err) } sent := "abc123" @@ -452,13 +449,13 @@ func TestPacketConnTransfer(t *testing.T) { addr2 := tcpip.FullAddress{NICID, ip2, 11311} s.AddAddress(NICID, ipv4.ProtocolNumber, ip2) - c1, err := NewPacketConn(s, addr1, ipv4.ProtocolNumber) + c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber) if err != nil { - t.Fatal("NewPacketConn(port 4):", err) + t.Fatal("DialUDP(bind port 4):", err) } - c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber) + c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber) if err != nil { - t.Fatal("NewPacketConn(port 5):", err) + t.Fatal("DialUDP(bind port 5):", err) } c1.SetDeadline(time.Now().Add(time.Second)) @@ -491,6 +488,50 @@ func TestPacketConnTransfer(t *testing.T) { } } +func TestConnectedPacketConnTransfer(t *testing.T) { + s, e := newLoopbackStack() + if e != nil { + t.Fatalf("newLoopbackStack() = %v", e) + } + + ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) + addr := tcpip.FullAddress{NICID, ip, 11211} + s.AddAddress(NICID, ipv4.ProtocolNumber, ip) + + c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber) + if err != nil { + t.Fatal("DialUDP(bind port 4):", err) + } + c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber) + if err != nil { + t.Fatal("DialUDP(bind port 5):", err) + } + + c1.SetDeadline(time.Now().Add(time.Second)) + c2.SetDeadline(time.Now().Add(time.Second)) + + sent := "abc123" + if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) { + t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil) + } + recv := make([]byte, len(sent)) + n, err := c1.Read(recv) + if err != nil || n != len(recv) { + t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil) + } + + if recv := string(recv); recv != sent { + t.Errorf("got recv = %q, want = %q", recv, sent) + } + + if err := c1.Close(); err != nil { + t.Error("c1.Close():", err) + } + if err := c2.Close(); err != nil { + t.Error("c2.Close():", err) + } +} + func makePipe() (c1, c2 net.Conn, stop func(), err error) { s, e := newLoopbackStack() if e != nil { diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 94a3af289..17fc9c68e 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -111,6 +111,15 @@ const ( IPv4FlagDontFragment ) +// IPv4EmptySubnet is the empty IPv4 subnet. +var IPv4EmptySubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any)) + if err != nil { + panic(err) + } + return subnet +}() + // IPVersion returns the version of IP used in the given packet. It returns -1 // if the packet is not large enough to contain the version field. func IPVersion(b []byte) int { diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 95fe8bfc3..bc4e56535 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -27,7 +27,7 @@ const ( nextHdr = 6 hopLimit = 7 v6SrcAddr = 8 - v6DstAddr = 24 + v6DstAddr = v6SrcAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -82,6 +82,15 @@ const ( IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" ) +// IPv6EmptySubnet is the empty IPv6 subnet. +var IPv6EmptySubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any)) + if err != nil { + panic(err) + } + return subnet +}() + // PayloadLength returns the value of the "payload length" field of the ipv6 // header. func (b IPv6) PayloadLength() uint16 { @@ -110,13 +119,13 @@ func (b IPv6) Payload() []byte { // SourceAddress returns the "source address" field of the ipv6 header. func (b IPv6) SourceAddress() tcpip.Address { - return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize]) + return tcpip.Address(b[v6SrcAddr:][:IPv6AddressSize]) } // DestinationAddress returns the "destination address" field of the ipv6 // header. func (b IPv6) DestinationAddress() tcpip.Address { - return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize]) + return tcpip.Address(b[v6DstAddr:][:IPv6AddressSize]) } // Checksum implements Network.Checksum. Given that IPv6 doesn't have a @@ -144,13 +153,13 @@ func (b IPv6) SetPayloadLength(payloadLength uint16) { // SetSourceAddress sets the "source address" field of the ipv6 header. func (b IPv6) SetSourceAddress(addr tcpip.Address) { - copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr) + copy(b[v6SrcAddr:][:IPv6AddressSize], addr) } // SetDestinationAddress sets the "destination address" field of the ipv6 // header. func (b IPv6) SetDestinationAddress(addr tcpip.Address) { - copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr) + copy(b[v6DstAddr:][:IPv6AddressSize], addr) } // SetNextHeader sets the value of the "next header" field of the ipv6 header. @@ -169,8 +178,8 @@ func (b IPv6) Encode(i *IPv6Fields) { b.SetPayloadLength(i.PayloadLength) b[nextHdr] = i.NextHeader b[hopLimit] = i.HopLimit - copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr) - copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr) + b.SetSourceAddress(i.SrcAddr) + b.SetDestinationAddress(i.DstAddr) } // IsValid performs basic validation on the packet. diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index fc9abbb55..3fc14bacd 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -11,8 +11,5 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables", visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - ], + deps = ["//pkg/tcpip/buffer"], ) diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index f1e1d1fad..68c68d4aa 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -32,8 +32,8 @@ const ( // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() *IPTables { - return &IPTables{ +func DefaultTables() IPTables { + return IPTables{ Tables: map[string]Table{ tablenameNat: Table{ BuiltinChains: map[Hook]Chain{ diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 600bd9a10..42a79ef9f 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -15,7 +15,6 @@ package iptables import ( - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -128,15 +127,29 @@ type Table struct { // UserChains, and its purpose is to make looking up tables by name // fast. Chains map[string]*Chain + + // Metadata holds information about the Table that is useful to users + // of IPTables, but not to the netstack IPTables code itself. + metadata interface{} } // ValidHooks returns a bitmap of the builtin hooks for the given table. -func (table *Table) ValidHooks() (uint32, *tcpip.Error) { +func (table *Table) ValidHooks() uint32 { hooks := uint32(0) for hook, _ := range table.BuiltinChains { hooks |= 1 << hook } - return hooks, nil + return hooks +} + +// Metadata returns the metadata object stored in table. +func (table *Table) Metadata() interface{} { + return table.metadata +} + +// SetMetadata sets the metadata object stored in table. +func (table *Table) SetMetadata(metadata interface{}) { + table.metadata = metadata } // A Chain defines a list of rules for packet processing. When a packet diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index d786d8fdf..74fbbb896 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -8,8 +8,8 @@ go_library( "endpoint.go", "endpoint_unsafe.go", "mmap.go", - "mmap_amd64.go", - "mmap_amd64_unsafe.go", + "mmap_stub.go", + "mmap_unsafe.go", "packet_dispatchers.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index fe19c2bc2..8bfeb97e4 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -12,14 +12,183 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !linux !amd64 +// +build linux,amd64 linux,arm64 package fdbased -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "encoding/binary" + "syscall" -// Stubbed out version for non-linux/non-amd64 platforms. + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" +) -func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, *tcpip.Error) { - return nil, nil +const ( + tPacketAlignment = uintptr(16) + tpStatusKernel = 0 + tpStatusUser = 1 + tpStatusCopy = 2 + tpStatusLosing = 4 +) + +// We overallocate the frame size to accommodate space for the +// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding. +// +// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB +// +// NOTE: +// Frames need to be aligned at 16 byte boundaries. +// BlockSize needs to be page aligned. +// +// For details see PACKET_MMAP setting constraints in +// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt +const ( + tpFrameSize = 65536 + 128 + tpBlockSize = tpFrameSize * 32 + tpBlockNR = 1 + tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize +) + +// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct +// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>. +func tPacketAlign(v uintptr) uintptr { + return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1)) +} + +// tPacketReq is the tpacket_req structure as described in +// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt +type tPacketReq struct { + tpBlockSize uint32 + tpBlockNR uint32 + tpFrameSize uint32 + tpFrameNR uint32 +} + +// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h> +type tPacketHdr []byte + +const ( + tpStatusOffset = 0 + tpLenOffset = 8 + tpSnapLenOffset = 12 + tpMacOffset = 16 + tpNetOffset = 18 + tpSecOffset = 20 + tpUSecOffset = 24 +) + +func (t tPacketHdr) tpLen() uint32 { + return binary.LittleEndian.Uint32(t[tpLenOffset:]) +} + +func (t tPacketHdr) tpSnapLen() uint32 { + return binary.LittleEndian.Uint32(t[tpSnapLenOffset:]) +} + +func (t tPacketHdr) tpMac() uint16 { + return binary.LittleEndian.Uint16(t[tpMacOffset:]) +} + +func (t tPacketHdr) tpNet() uint16 { + return binary.LittleEndian.Uint16(t[tpNetOffset:]) +} + +func (t tPacketHdr) tpSec() uint32 { + return binary.LittleEndian.Uint32(t[tpSecOffset:]) +} + +func (t tPacketHdr) tpUSec() uint32 { + return binary.LittleEndian.Uint32(t[tpUSecOffset:]) +} + +func (t tPacketHdr) Payload() []byte { + return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()] +} + +// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets. +// See: mmap_amd64_unsafe.go for implementation details. +type packetMMapDispatcher struct { + // fd is the file descriptor used to send and receive packets. + fd int + + // e is the endpoint this dispatcher is attached to. + e *endpoint + + // ringBuffer is only used when PacketMMap dispatcher is used and points + // to the start of the mmapped PACKET_RX_RING buffer. + ringBuffer []byte + + // ringOffset is the current offset into the ring buffer where the next + // inbound packet will be placed by the kernel. + ringOffset int +} + +func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { + hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) + for hdr.tpStatus()&tpStatusUser == 0 { + event := rawfile.PollEvent{ + FD: int32(d.fd), + Events: unix.POLLIN | unix.POLLERR, + } + if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 { + if errno == syscall.EINTR { + continue + } + return nil, rawfile.TranslateErrno(errno) + } + if hdr.tpStatus()&tpStatusCopy != 0 { + // This frame is truncated so skip it after flipping the + // buffer to the kernel. + hdr.setTPStatus(tpStatusKernel) + d.ringOffset = (d.ringOffset + 1) % tpFrameNR + hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:]) + continue + } + } + + // Copy out the packet from the mmapped frame to a locally owned buffer. + pkt := make([]byte, hdr.tpSnapLen()) + copy(pkt, hdr.Payload()) + // Release packet to kernel. + hdr.setTPStatus(tpStatusKernel) + d.ringOffset = (d.ringOffset + 1) % tpFrameNR + return pkt, nil +} + +// dispatch reads packets from an mmaped ring buffer and dispatches them to the +// network stack. +func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { + pkt, err := d.readMMappedPacket() + if err != nil { + return false, err + } + var ( + p tcpip.NetworkProtocolNumber + remote, local tcpip.LinkAddress + ) + if d.e.hdrSize > 0 { + eth := header.Ethernet(pkt) + p = eth.Type() + remote = eth.SourceAddress() + local = eth.DestinationAddress() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + switch header.IPVersion(pkt) { + case header.IPv4Version: + p = header.IPv4ProtocolNumber + case header.IPv6Version: + p = header.IPv6ProtocolNumber + default: + return true, nil + } + } + + pkt = pkt[d.e.hdrSize:] + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)})) + return true, nil } diff --git a/pkg/tcpip/link/fdbased/mmap_amd64.go b/pkg/tcpip/link/fdbased/mmap_amd64.go deleted file mode 100644 index 8bbb4f9ab..000000000 --- a/pkg/tcpip/link/fdbased/mmap_amd64.go +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux,amd64 - -package fdbased - -import ( - "encoding/binary" - "syscall" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" -) - -const ( - tPacketAlignment = uintptr(16) - tpStatusKernel = 0 - tpStatusUser = 1 - tpStatusCopy = 2 - tpStatusLosing = 4 -) - -// We overallocate the frame size to accommodate space for the -// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding. -// -// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB -// -// NOTE: -// Frames need to be aligned at 16 byte boundaries. -// BlockSize needs to be page aligned. -// -// For details see PACKET_MMAP setting constraints in -// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt -const ( - tpFrameSize = 65536 + 128 - tpBlockSize = tpFrameSize * 32 - tpBlockNR = 1 - tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize -) - -// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct -// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>. -func tPacketAlign(v uintptr) uintptr { - return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1)) -} - -// tPacketReq is the tpacket_req structure as described in -// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt -type tPacketReq struct { - tpBlockSize uint32 - tpBlockNR uint32 - tpFrameSize uint32 - tpFrameNR uint32 -} - -// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h> -type tPacketHdr []byte - -const ( - tpStatusOffset = 0 - tpLenOffset = 8 - tpSnapLenOffset = 12 - tpMacOffset = 16 - tpNetOffset = 18 - tpSecOffset = 20 - tpUSecOffset = 24 -) - -func (t tPacketHdr) tpLen() uint32 { - return binary.LittleEndian.Uint32(t[tpLenOffset:]) -} - -func (t tPacketHdr) tpSnapLen() uint32 { - return binary.LittleEndian.Uint32(t[tpSnapLenOffset:]) -} - -func (t tPacketHdr) tpMac() uint16 { - return binary.LittleEndian.Uint16(t[tpMacOffset:]) -} - -func (t tPacketHdr) tpNet() uint16 { - return binary.LittleEndian.Uint16(t[tpNetOffset:]) -} - -func (t tPacketHdr) tpSec() uint32 { - return binary.LittleEndian.Uint32(t[tpSecOffset:]) -} - -func (t tPacketHdr) tpUSec() uint32 { - return binary.LittleEndian.Uint32(t[tpUSecOffset:]) -} - -func (t tPacketHdr) Payload() []byte { - return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()] -} - -// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets. -// See: mmap_amd64_unsafe.go for implementation details. -type packetMMapDispatcher struct { - // fd is the file descriptor used to send and receive packets. - fd int - - // e is the endpoint this dispatcher is attached to. - e *endpoint - - // ringBuffer is only used when PacketMMap dispatcher is used and points - // to the start of the mmapped PACKET_RX_RING buffer. - ringBuffer []byte - - // ringOffset is the current offset into the ring buffer where the next - // inbound packet will be placed by the kernel. - ringOffset int -} - -func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { - hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) - for hdr.tpStatus()&tpStatusUser == 0 { - event := rawfile.PollEvent{ - FD: int32(d.fd), - Events: unix.POLLIN | unix.POLLERR, - } - if _, errno := rawfile.BlockingPoll(&event, 1, -1); errno != 0 { - if errno == syscall.EINTR { - continue - } - return nil, rawfile.TranslateErrno(errno) - } - if hdr.tpStatus()&tpStatusCopy != 0 { - // This frame is truncated so skip it after flipping the - // buffer to the kernel. - hdr.setTPStatus(tpStatusKernel) - d.ringOffset = (d.ringOffset + 1) % tpFrameNR - hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:]) - continue - } - } - - // Copy out the packet from the mmapped frame to a locally owned buffer. - pkt := make([]byte, hdr.tpSnapLen()) - copy(pkt, hdr.Payload()) - // Release packet to kernel. - hdr.setTPStatus(tpStatusKernel) - d.ringOffset = (d.ringOffset + 1) % tpFrameNR - return pkt, nil -} - -// dispatch reads packets from an mmaped ring buffer and dispatches them to the -// network stack. -func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { - pkt, err := d.readMMappedPacket() - if err != nil { - return false, err - } - var ( - p tcpip.NetworkProtocolNumber - remote, local tcpip.LinkAddress - ) - if d.e.hdrSize > 0 { - eth := header.Ethernet(pkt) - p = eth.Type() - remote = eth.SourceAddress() - local = eth.DestinationAddress() - } else { - // We don't get any indication of what the packet is, so try to guess - // if it's an IPv4 or IPv6 packet. - switch header.IPVersion(pkt) { - case header.IPv4Version: - p = header.IPv4ProtocolNumber - case header.IPv6Version: - p = header.IPv6ProtocolNumber - default: - return true, nil - } - } - - pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)})) - return true, nil -} diff --git a/pkg/tcpip/link/fdbased/mmap_stub.go b/pkg/tcpip/link/fdbased/mmap_stub.go new file mode 100644 index 000000000..67be52d67 --- /dev/null +++ b/pkg/tcpip/link/fdbased/mmap_stub.go @@ -0,0 +1,23 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !linux !amd64,!arm64 + +package fdbased + +// Stubbed out version for non-linux/non-amd64/non-arm64 platforms. + +func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + return nil, nil +} diff --git a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go index 47cb1d1cc..3894185ae 100644 --- a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go +++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build linux,amd64 +// +build linux,amd64 linux,arm64 package fdbased diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 6e3a7a9d7..088eb8a21 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -6,8 +6,10 @@ go_library( name = "rawfile", srcs = [ "blockingpoll_amd64.s", - "blockingpoll_amd64_unsafe.go", + "blockingpoll_arm64.s", + "blockingpoll_noyield_unsafe.go", "blockingpoll_unsafe.go", + "blockingpoll_yield_unsafe.go", "errors.go", "rawfile_unsafe.go", ], diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s index b54131573..298bad55d 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s +++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s @@ -14,17 +14,18 @@ #include "textflag.h" -// BlockingPoll makes the poll() syscall while calling the version of +// BlockingPoll makes the ppoll() syscall while calling the version of // entersyscall that relinquishes the P so that other Gs can run. This is meant // to be called in cases when the syscall is expected to block. // -// func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (n int, err syscall.Errno) +// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno) TEXT ·BlockingPoll(SB),NOSPLIT,$0-40 CALL ·callEntersyscallblock(SB) MOVQ fds+0(FP), DI MOVQ nfds+8(FP), SI MOVQ timeout+16(FP), DX - MOVQ $0x7, AX // SYS_POLL + MOVQ $0x0, R10 // sigmask parameter which isn't used here + MOVQ $0x10f, AX // SYS_PPOLL SYSCALL CMPQ AX, $0xfffffffffffff001 JLS ok diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s new file mode 100644 index 000000000..b62888b93 --- /dev/null +++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s @@ -0,0 +1,42 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// BlockingPoll makes the ppoll() syscall while calling the version of +// entersyscall that relinquishes the P so that other Gs can run. This is meant +// to be called in cases when the syscall is expected to block. +// +// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno) +TEXT ·BlockingPoll(SB),NOSPLIT,$0-40 + BL ·callEntersyscallblock(SB) + MOVD fds+0(FP), R0 + MOVD nfds+8(FP), R1 + MOVD timeout+16(FP), R2 + MOVD $0x0, R3 // sigmask parameter which isn't used here + MOVD $0x49, R8 // SYS_PPOLL + SVC + CMP $0xfffffffffffff001, R0 + BLS ok + MOVD $-1, R1 + MOVD R1, n+24(FP) + NEG R0, R0 + MOVD R0, err+32(FP) + BL ·callExitsyscall(SB) + RET +ok: + MOVD R0, n+24(FP) + MOVD $0, err+32(FP) + BL ·callExitsyscall(SB) + RET diff --git a/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go new file mode 100644 index 000000000..621ab8d29 --- /dev/null +++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go @@ -0,0 +1,31 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux,!amd64,!arm64 + +package rawfile + +import ( + "syscall" + "unsafe" +) + +// BlockingPoll is just a stub function that forwards to the ppoll() system call +// on non-amd64 and non-arm64 platforms. +func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) { + n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)), + uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0) + + return int(n), e +} diff --git a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go index 4eab77c74..84dc0e918 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go @@ -21,9 +21,11 @@ import ( "unsafe" ) -// BlockingPoll is just a stub function that forwards to the poll() system call +// BlockingPoll is just a stub function that forwards to the ppoll() system call // on non-amd64 platforms. -func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (int, syscall.Errno) { - n, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(fds)), uintptr(nfds), uintptr(timeout)) +func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) { + n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)), + uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0) + return int(n), e } diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index c87268610..dda3b10a6 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build linux,amd64 +// +build linux,amd64 linux,arm64 // +build go1.12 // +build !go1.14 @@ -25,8 +25,14 @@ import ( _ "unsafe" // for go:linkname ) +// BlockingPoll on amd64/arm64 makes the ppoll() syscall while calling the +// version of entersyscall that relinquishes the P so that other Gs can +// run. This is meant to be called in cases when the syscall is expected to +// block. On non amd64/arm64 platforms it just forwards to the ppoll() system +// call. +// //go:noescape -func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (int, syscall.Errno) +func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) // Use go:linkname to call into the runtime. As of Go 1.12 this has to // be done from Go code so that we make an ABIInternal call to an diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index e3fbb15c2..7e286a3a6 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -123,7 +123,7 @@ func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { Events: 1, // POLLIN } - _, e = BlockingPoll(&event, 1, -1) + _, e = BlockingPoll(&event, 1, nil) if e != 0 && e != syscall.EINTR { return 0, TranslateErrno(e) } @@ -145,7 +145,7 @@ func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) { Events: 1, // POLLIN } - _, e = BlockingPoll(&event, 1, -1) + _, e = BlockingPoll(&event, 1, nil) if e != 0 && e != syscall.EINTR { return 0, TranslateErrno(e) } @@ -175,7 +175,7 @@ func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { Events: 1, // POLLIN } - if _, e := BlockingPoll(&event, 1, -1); e != 0 && e != syscall.EINTR { + if _, e := BlockingPoll(&event, 1, nil); e != 0 && e != syscall.EINTR { return 0, TranslateErrno(e) } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index fc584c6a4..36c8c46fc 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -360,10 +360,9 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() + details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) + size -= header.UDPMinimumSize } - size -= header.UDPMinimumSize - - details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) case header.TCPProtocolNumber: transName = "tcp" diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 60070874d..fd6395fc1 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -109,13 +109,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { pkt.SetOp(header.ARPReply) copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) + copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender()) copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) - fallthrough // also fill the cache from requests case header.ARPReply: - addr := tcpip.Address(h.ProtocolAddressSender()) - linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr) } } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 66c55821b..4c4b54469 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -15,6 +15,7 @@ package arp_test import ( + "strconv" "testing" "time" @@ -65,9 +66,7 @@ func newTestContext(t *testing.T) *testContext { } s.SetRouteTable([]tcpip.Route{{ - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }}) @@ -101,40 +100,30 @@ func TestDirectRequest(t *testing.T) { c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) } - inject(stackAddr1) - { - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto) - } - rep := header.ARP(pkt.Header) - if !rep.IsValid() { - t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) - } - if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 { - t.Errorf("stackAddr1: expected sender to be set") - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got) - } - } - - inject(stackAddr2) - { - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto) - } - rep := header.ARP(pkt.Header) - if !rep.IsValid() { - t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) - } - if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 { - t.Errorf("stackAddr2: expected sender to be set") - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { - t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got) - } + for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { + t.Run(strconv.Itoa(i), func(t *testing.T) { + inject(address) + pkt := <-c.linkEP.C + if pkt.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto) + } + rep := header.ARP(pkt.Header) + if !rep.IsValid() { + t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) + } + if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { + t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) + } + if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want) + } + }) } inject(stackAddrBad) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 55e9eec99..6bbfcd97f 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -173,8 +173,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv4.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ - Destination: ipv4SubnetAddr, - Mask: ipv4SubnetMask, + Destination: header.IPv4EmptySubnet, Gateway: ipv4Gateway, NIC: 1, }}) @@ -187,8 +186,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv6.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ - Destination: ipv6SubnetAddr, - Mask: ipv6SubnetMask, + Destination: header.IPv6EmptySubnet, Gateway: ipv6Gateway, NIC: 1, }}) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index fbef6947d..497164cbb 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -94,6 +94,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(pkt, h) pkt.SetType(header.ICMPv4EchoReply) + pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil { diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 3207a3d46..1b5a55bea 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -52,9 +52,7 @@ func TestExcludeBroadcast(t *testing.T) { } s.SetRouteTable([]tcpip.Route{{ - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }}) @@ -247,14 +245,22 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32 _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) linkEPId := stack.RegisterLinkEndpoint(linkEP) s.CreateNIC(1, linkEPId) - s.AddAddress(1, ipv4.ProtocolNumber, "\x10\x00\x00\x01") - s.SetRouteTable([]tcpip.Route{{ - Destination: "\x10\x00\x00\x02", - Mask: "\xff\xff\xff\xff", - Gateway: "", - NIC: 1, - }}) - r, err := s.FindRoute(0, "\x10\x00\x00\x01", "\x10\x00\x00\x02", ipv4.ProtocolNumber, false /* multicastLoop */) + const ( + src = "\x10\x00\x00\x01" + dst = "\x10\x00\x00\x02" + ) + s.AddAddress(1, ipv4.ProtocolNumber, src) + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) if err != nil { t.Fatalf("s.FindRoute got %v, want %v", err, nil) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 5e6a59e91..1689af16f 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -100,13 +100,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) - if len(v) < header.ICMPv6NeighborSolicitMinimumSize { received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:16]) + targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { // We don't have a useful answer; the best we can do is ignore the request. return @@ -146,7 +144,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:16]) + targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) if targetAddr != r.RemoteAddress { e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 726362c87..d0dc72506 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -91,13 +91,18 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) } } - s.SetRouteTable( - []tcpip.Route{{ - Destination: lladdr1, - Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), - NIC: 1, - }}, - ) + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } netProto := s.NetworkProtocolInstance(ProtocolNumber) if netProto == nil { @@ -237,17 +242,23 @@ func newTestContext(t *testing.T) *testContext { t.Fatalf("AddAddress sn lladdr1: %v", err) } + subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } c.s0.SetRouteTable( []tcpip.Route{{ - Destination: lladdr1, - Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), + Destination: subnet0, NIC: 1, }}, ) + subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) + if err != nil { + t.Fatal(err) + } c.s1.SetRouteTable( []tcpip.Route{{ - Destination: lladdr0, - Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)), + Destination: subnet1, NIC: 1, }}, ) diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index 996939581..a57752a7c 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -8,6 +8,7 @@ go_binary( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/rawfile", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 3ac381631..e2021cd15 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -52,6 +52,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -152,9 +153,7 @@ func main() { // Add default route. s.SetRouteTable([]tcpip.Route{ { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }, }) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index da425394a..1716be285 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -149,12 +149,15 @@ func main() { log.Fatal(err) } + subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) + if err != nil { + log.Fatal(err) + } + // Add default route. s.SetRouteTable([]tcpip.Route{ { - Destination: tcpip.Address(strings.Repeat("\x00", len(addr))), - Mask: tcpip.AddressMask(strings.Repeat("\x00", len(addr))), - Gateway: "", + Destination: subnet, NIC: 1, }, }) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 28d11c797..b692c60ce 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,11 +1,25 @@ package(licenses = ["notice"]) +load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +go_template_instance( + name = "linkaddrentry_list", + out = "linkaddrentry_list.go", + package = "stack", + prefix = "linkAddrEntry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*linkAddrEntry", + "Linker": "*linkAddrEntry", + }, +) + go_library( name = "stack", srcs = [ "linkaddrcache.go", + "linkaddrentry_list.go", "nic.go", "registration.go", "route.go", @@ -24,6 +38,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/waiter", @@ -42,6 +57,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/waiter", @@ -58,3 +74,11 @@ go_test( "//pkg/tcpip", ], ) + +filegroup( + name = "autogen", + srcs = [ + "linkaddrentry_list.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 77bb0ccb9..267df60d1 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -42,10 +42,11 @@ type linkAddrCache struct { // resolved before failing. resolutionAttempts int - mu sync.Mutex - cache map[tcpip.FullAddress]*linkAddrEntry - next int // array index of next available entry - entries [linkAddrCacheSize]linkAddrEntry + cache struct { + sync.Mutex + table map[tcpip.FullAddress]*linkAddrEntry + lru linkAddrEntryList + } } // entryState controls the state of a single entry in the cache. @@ -60,9 +61,6 @@ const ( // failed means that address resolution timed out and the address // could not be resolved. failed - // expired means that the cache entry has expired and the address must be - // resolved again. - expired ) // String implements Stringer. @@ -74,8 +72,6 @@ func (s entryState) String() string { return "ready" case failed: return "failed" - case expired: - return "expired" default: return fmt.Sprintf("unknown(%d)", s) } @@ -84,64 +80,46 @@ func (s entryState) String() string { // A linkAddrEntry is an entry in the linkAddrCache. // This struct is thread-compatible. type linkAddrEntry struct { + linkAddrEntryEntry + addr tcpip.FullAddress linkAddr tcpip.LinkAddress expiration time.Time s entryState // wakers is a set of waiters for address resolution result. Anytime - // state transitions out of 'incomplete' these waiters are notified. + // state transitions out of incomplete these waiters are notified. wakers map[*sleep.Waker]struct{} + // done is used to allow callers to wait on address resolution. It is nil iff + // s is incomplete and resolution is not yet in progress. done chan struct{} } -func (e *linkAddrEntry) state() entryState { - if e.s != expired && time.Now().After(e.expiration) { - // Force the transition to ensure waiters are notified. - e.changeState(expired) - } - return e.s -} - -func (e *linkAddrEntry) changeState(ns entryState) { - if e.s == ns { - return - } - - // Validate state transition. - switch e.s { - case incomplete: - // All transitions are valid. - case ready, failed: - if ns != expired { - panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) - } - case expired: - // Terminal state. - panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) - default: - panic(fmt.Sprintf("invalid state: %s", e.s)) - } - +// changeState sets the entry's state to ns, notifying any waiters. +// +// The entry's expiration is bumped up to the greater of itself and the passed +// expiration; the zero value indicates immediate expiration, and is set +// unconditionally - this is an implementation detail that allows for entries +// to be reused. +func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { // Notify whoever is waiting on address resolution when transitioning - // out of 'incomplete'. - if e.s == incomplete { + // out of incomplete. + if e.s == incomplete && ns != incomplete { for w := range e.wakers { w.Assert() } e.wakers = nil - if e.done != nil { - close(e.done) + if ch := e.done; ch != nil { + close(ch) } + e.done = nil } - e.s = ns -} -func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) { - if w != nil { - e.wakers[w] = struct{}{} + if expiration.IsZero() || expiration.After(e.expiration) { + e.expiration = expiration } + e.s = ns } func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { @@ -150,53 +128,54 @@ func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { // add adds a k -> v mapping to the cache. func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { - c.mu.Lock() - defer c.mu.Unlock() - - entry, ok := c.cache[k] - if ok { - s := entry.state() - if s != expired && entry.linkAddr == v { - // Disregard repeated calls. - return - } - // Check if entry is waiting for address resolution. - if s == incomplete { - entry.linkAddr = v - } else { - // Otherwise create a new entry to replace it. - entry = c.makeAndAddEntry(k, v) - } - } else { - entry = c.makeAndAddEntry(k, v) - } + // Calculate expiration time before acquiring the lock, since expiration is + // relative to the time when information was learned, rather than when it + // happened to be inserted into the cache. + expiration := time.Now().Add(c.ageLimit) - entry.changeState(ready) + c.cache.Lock() + entry := c.getOrCreateEntryLocked(k) + entry.linkAddr = v + + entry.changeState(ready, expiration) + c.cache.Unlock() } -// makeAndAddEntry is a helper function to create and add a new -// entry to the cache map and evict older entry as needed. -func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry { - // Take over the next entry. - entry := &c.entries[c.next] - if c.cache[entry.addr] == entry { - delete(c.cache, entry.addr) +// getOrCreateEntryLocked retrieves a cache entry associated with k. The +// returned entry is always refreshed in the cache (it is reachable via the +// map, and its place is bumped in LRU). +// +// If a matching entry exists in the cache, it is returned. If no matching +// entry exists and the cache is full, an existing entry is evicted via LRU, +// reset to state incomplete, and returned. If no matching entry exists and the +// cache is not full, a new entry with state incomplete is allocated and +// returned. +func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry { + if entry, ok := c.cache.table[k]; ok { + c.cache.lru.Remove(entry) + c.cache.lru.PushFront(entry) + return entry } + var entry *linkAddrEntry + if len(c.cache.table) == linkAddrCacheSize { + entry = c.cache.lru.Back() - // Mark the soon-to-be-replaced entry as expired, just in case there is - // someone waiting for address resolution on it. - entry.changeState(expired) + delete(c.cache.table, entry.addr) + c.cache.lru.Remove(entry) - *entry = linkAddrEntry{ - addr: k, - linkAddr: v, - expiration: time.Now().Add(c.ageLimit), - wakers: make(map[*sleep.Waker]struct{}), - done: make(chan struct{}), + // Wake waiters and mark the soon-to-be-reused entry as expired. Note + // that the state passed doesn't matter when the zero time is passed. + entry.changeState(failed, time.Time{}) + } else { + entry = new(linkAddrEntry) } - c.cache[k] = entry - c.next = (c.next + 1) % len(c.entries) + *entry = linkAddrEntry{ + addr: k, + s: incomplete, + } + c.cache.table[k] = entry + c.cache.lru.PushFront(entry) return entry } @@ -208,43 +187,55 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } } - c.mu.Lock() - defer c.mu.Unlock() - if entry, ok := c.cache[k]; ok { - switch s := entry.state(); s { - case expired: - case ready: - return entry.linkAddr, nil, nil - case failed: - return "", nil, tcpip.ErrNoLinkAddress - case incomplete: - // Address resolution is still in progress. - entry.maybeAddWaker(waker) - return "", entry.done, tcpip.ErrWouldBlock - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) + c.cache.Lock() + defer c.cache.Unlock() + entry := c.getOrCreateEntryLocked(k) + switch s := entry.s; s { + case ready, failed: + if !time.Now().After(entry.expiration) { + // Not expired. + switch s { + case ready: + return entry.linkAddr, nil, nil + case failed: + return entry.linkAddr, nil, tcpip.ErrNoLinkAddress + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } } - } - if linkRes == nil { - return "", nil, tcpip.ErrNoLinkAddress - } + entry.changeState(incomplete, time.Time{}) + fallthrough + case incomplete: + if waker != nil { + if entry.wakers == nil { + entry.wakers = make(map[*sleep.Waker]struct{}) + } + entry.wakers[waker] = struct{}{} + } - // Add 'incomplete' entry in the cache to mark that resolution is in progress. - e := c.makeAndAddEntry(k, "") - e.maybeAddWaker(waker) + if entry.done == nil { + // Address resolution needs to be initiated. + if linkRes == nil { + return entry.linkAddr, nil, tcpip.ErrNoLinkAddress + } - go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + entry.done = make(chan struct{}) + go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + } - return "", e.done, tcpip.ErrWouldBlock + return entry.linkAddr, entry.done, tcpip.ErrWouldBlock + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } } // removeWaker removes a waker previously added through get(). func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { - c.mu.Lock() - defer c.mu.Unlock() + c.cache.Lock() + defer c.cache.Unlock() - if entry, ok := c.cache[k]; ok { + if entry, ok := c.cache.table[k]; ok { entry.removeWaker(waker) } } @@ -256,8 +247,8 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) select { - case <-time.After(c.resolutionTimeout): - if stop := c.checkLinkRequest(k, i); stop { + case now := <-time.After(c.resolutionTimeout): + if stop := c.checkLinkRequest(now, k, i); stop { return } case <-done: @@ -269,38 +260,36 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link // checkLinkRequest checks whether previous attempt to resolve address has succeeded // and mark the entry accordingly, e.g. ready, failed, etc. Return true if request // can stop, false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool { - c.mu.Lock() - defer c.mu.Unlock() - - entry, ok := c.cache[k] +func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { + c.cache.Lock() + defer c.cache.Unlock() + entry, ok := c.cache.table[k] if !ok { // Entry was evicted from the cache. return true } - - switch s := entry.state(); s { - case ready, failed, expired: + switch s := entry.s; s { + case ready, failed: // Entry was made ready by resolver or failed. Either way we're done. - return true case incomplete: - if attempt+1 >= c.resolutionAttempts { - // Max number of retries reached, mark entry as failed. - entry.changeState(failed) - return true + if attempt+1 < c.resolutionAttempts { + // No response yet, need to send another ARP request. + return false } - // No response yet, need to send another ARP request. - return false + // Max number of retries reached, mark entry as failed. + entry.changeState(failed, now.Add(c.ageLimit)) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } + return true } func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { - return &linkAddrCache{ + c := &linkAddrCache{ ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, - cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize), } + c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize) + return c } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 924f4d240..9946b8fe8 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -17,6 +17,7 @@ package stack import ( "fmt" "sync" + "sync/atomic" "testing" "time" @@ -29,25 +30,34 @@ type testaddr struct { linkAddr tcpip.LinkAddress } -var testaddrs []testaddr +var testAddrs = func() []testaddr { + var addrs []testaddr + for i := 0; i < 4*linkAddrCacheSize; i++ { + addr := fmt.Sprintf("Addr%06d", i) + addrs = append(addrs, testaddr{ + addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, + linkAddr: tcpip.LinkAddress("Link" + addr), + }) + } + return addrs +}() type testLinkAddressResolver struct { - cache *linkAddrCache - delay time.Duration + cache *linkAddrCache + delay time.Duration + onLinkAddressRequest func() } func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { - go func() { - if r.delay > 0 { - time.Sleep(r.delay) - } - r.fakeRequest(addr) - }() + time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) + if f := r.onLinkAddressRequest; f != nil { + f() + } return nil } func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { - for _, ta := range testaddrs { + for _, ta := range testAddrs { if ta.addr.Addr == addr { r.cache.add(ta.addr, ta.linkAddr) break @@ -80,20 +90,10 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe } } -func init() { - for i := 0; i < 4*linkAddrCacheSize; i++ { - addr := fmt.Sprintf("Addr%06d", i) - testaddrs = append(testaddrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, - linkAddr: tcpip.LinkAddress("Link" + addr), - }) - } -} - func TestCacheOverflow(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - for i := len(testaddrs) - 1; i >= 0; i-- { - e := testaddrs[i] + for i := len(testAddrs) - 1; i >= 0; i-- { + e := testAddrs[i] c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { @@ -105,7 +105,7 @@ func TestCacheOverflow(t *testing.T) { } // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { - e := testaddrs[i] + e := testAddrs[i] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) @@ -115,8 +115,8 @@ func TestCacheOverflow(t *testing.T) { } } // The earliest entries should no longer be in the cache. - for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- { - e := testaddrs[i] + for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { + e := testAddrs[i] if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) } @@ -130,7 +130,7 @@ func TestCacheConcurrent(t *testing.T) { for r := 0; r < 16; r++ { wg.Add(1) go func() { - for _, e := range testaddrs { + for _, e := range testAddrs { c.add(e.addr, e.linkAddr) c.get(e.addr, nil, "", nil, nil) // make work for gotsan } @@ -142,7 +142,7 @@ func TestCacheConcurrent(t *testing.T) { // All goroutines add in the same order and add more values than // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. - e := testaddrs[len(testaddrs)-1] + e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -151,7 +151,7 @@ func TestCacheConcurrent(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) } - e = testaddrs[0] + e = testAddrs[0] if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } @@ -159,7 +159,7 @@ func TestCacheConcurrent(t *testing.T) { func TestCacheAgeLimit(t *testing.T) { c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) - e := testaddrs[0] + e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { @@ -169,7 +169,7 @@ func TestCacheAgeLimit(t *testing.T) { func TestCacheReplace(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - e := testaddrs[0] + e := testAddrs[0] l2 := e.linkAddr + "2" c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) @@ -193,7 +193,7 @@ func TestCacheReplace(t *testing.T) { func TestCacheResolution(t *testing.T) { c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) linkRes := &testLinkAddressResolver{cache: c} - for i, ta := range testaddrs { + for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) @@ -205,7 +205,7 @@ func TestCacheResolution(t *testing.T) { // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { - e := testaddrs[len(testaddrs)-1] + e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -220,8 +220,13 @@ func TestCacheResolutionFailed(t *testing.T) { c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) linkRes := &testLinkAddressResolver{cache: c} + var requestCount uint32 + linkRes.onLinkAddressRequest = func() { + atomic.AddUint32(&requestCount, 1) + } + // First, sanity check that resolution is working... - e := testaddrs[0] + e := testAddrs[0] got, err := getBlocking(c, e.addr, linkRes) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -230,10 +235,16 @@ func TestCacheResolutionFailed(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) } + before := atomic.LoadUint32(&requestCount) + e.addr.Addr += "2" if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } + + if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { + t.Errorf("got link address request count = %d, want = %d", got, want) + } } func TestCacheResolutionTimeout(t *testing.T) { @@ -242,7 +253,7 @@ func TestCacheResolutionTimeout(t *testing.T) { c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} - e := testaddrs[0] + e := testAddrs[0] if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3e6ff4afb..89b4c5960 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -139,7 +139,7 @@ func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Add if list, ok := n.primary[protocol]; ok { for e := list.Front(); e != nil; e = e.Next() { ref := e.(*referencedNetworkEndpoint) - if ref.holdsInsertRef && ref.tryIncRef() { + if ref.kind == permanent && ref.tryIncRef() { r = ref break } @@ -178,7 +178,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN case header.IPv4Broadcast, header.IPv4Any: continue } - if r.tryIncRef() { + if r.isValidForOutgoing() && r.tryIncRef() { return r } } @@ -186,82 +186,155 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return nil } +func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { + return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing) +} + +// getRefEpOrCreateTemp returns the referenced network endpoint for the given +// protocol and address. If none exists a temporary one may be created if +// we are in promiscuous mode or spoofing. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - ref := n.endpoints[id] - if ref != nil && !ref.tryIncRef() { - ref = nil + + if ref, ok := n.endpoints[id]; ok { + // An endpoint with this id exists, check if it can be used and return it. + switch ref.kind { + case permanentExpired: + if !spoofingOrPromiscuous { + n.mu.RUnlock() + return nil + } + fallthrough + case temporary, permanent: + if ref.tryIncRef() { + n.mu.RUnlock() + return ref + } + } + } + + // A usable reference was not found, create a temporary one if requested by + // the caller or if the address is found in the NIC's subnets. + createTempEP := spoofingOrPromiscuous + if !createTempEP { + for _, sn := range n.subnets { + if sn.Contains(address) { + createTempEP = true + break + } + } } - spoofing := n.spoofing + n.mu.RUnlock() - if ref != nil || !spoofing { - return ref + if !createTempEP { + return nil } // Try again with the lock in exclusive mode. If we still can't get the // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - ref = n.endpoints[id] - if ref == nil || !ref.tryIncRef() { - if netProto, ok := n.stack.networkProtocols[protocol]; ok { - addrWithPrefix := tcpip.AddressWithPrefix{address, netProto.DefaultPrefixLen()} - ref, _ = n.addAddressLocked(protocol, addrWithPrefix, peb, true) - if ref != nil { - ref.holdsInsertRef = false - } + if ref, ok := n.endpoints[id]; ok { + // No need to check the type as we are ok with expired endpoints at this + // point. + if ref.tryIncRef() { + n.mu.Unlock() + return ref } + // tryIncRef failing means the endpoint is scheduled to be removed once the + // lock is released. Remove it here so we can create a new (temporary) one. + // The removal logic waiting for the lock handles this case. + n.removeEndpointLocked(ref) } - n.mu.Unlock() - return ref -} -func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { + // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { - return nil, tcpip.ErrUnknownProtocol + n.mu.Unlock() + return nil } + ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: netProto.DefaultPrefixLen(), + }, + }, peb, temporary) - // Create the new network endpoint. - ep, err := netProto.NewEndpoint(n.id, addrWithPrefix, n.stack, n, n.linkEP) - if err != nil { - return nil, err - } + n.mu.Unlock() + return ref +} - id := *ep.ID() +func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if ref, ok := n.endpoints[id]; ok { - if !replace { + switch ref.kind { + case permanent: + // The NIC already have a permanent endpoint with that address. return nil, tcpip.ErrDuplicateAddress + case permanentExpired, temporary: + // Promote the endpoint to become permanent. + if ref.tryIncRef() { + ref.kind = permanent + return ref, nil + } + // tryIncRef failing means the endpoint is scheduled to be removed once + // the lock is released. Remove it here so we can create a new + // (permanent) one. The removal logic waiting for the lock handles this + // case. + n.removeEndpointLocked(ref) } + } + return n.addAddressLocked(protocolAddress, peb, permanent) +} - n.removeEndpointLocked(ref) +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { + // Sanity check. + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if _, ok := n.endpoints[id]; ok { + // Endpoint already exists. + return nil, tcpip.ErrDuplicateAddress } + netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] + if !ok { + return nil, tcpip.ErrUnknownProtocol + } + + // Create the new network endpoint. + ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP) + if err != nil { + return nil, err + } ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocol, - holdsInsertRef: true, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, } // Set up cache if link address resolution exists for this protocol. if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { - if _, ok := n.stack.linkAddrResolvers[protocol]; ok { + if _, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok { ref.linkCache = n.stack } } n.endpoints[id] = ref - l, ok := n.primary[protocol] + l, ok := n.primary[protocolAddress.Protocol] if !ok { l = &ilist.List{} - n.primary[protocol] = l + n.primary[protocolAddress.Protocol] = l } switch peb { @@ -276,10 +349,10 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addrWithPre // AddAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) *tcpip.Error { +func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addAddressLocked(protocol, addrWithPrefix, peb, false) + _, err := n.addPermanentAddressLocked(protocolAddress, peb) n.mu.Unlock() return err @@ -291,6 +364,12 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { defer n.mu.RUnlock() addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { + // Don't include expired or tempory endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.kind { + case permanentExpired, temporary: + continue + } addrs = append(addrs, tcpip.ProtocolAddress{ Protocol: ref.protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -356,13 +435,16 @@ func (n *NIC) Subnets() []tcpip.Subnet { func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { id := *r.ep.ID() - // Nothing to do if the reference has already been replaced with a - // different one. + // Nothing to do if the reference has already been replaced with a different + // one. This happens in the case where 1) this endpoint's ref count hit zero + // and was waiting (on the lock) to be removed and 2) the same address was + // re-added in the meantime by removing this endpoint from the list and + // adding a new one. if n.endpoints[id] != r { return } - if r.holdsInsertRef { + if r.kind == permanent { panic("Reference count dropped to zero before being removed") } @@ -381,14 +463,13 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { n.mu.Unlock() } -func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { +func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || !r.holdsInsertRef { + if r == nil || r.kind != permanent { return tcpip.ErrBadLocalAddress } - r.holdsInsertRef = false - + r.kind = permanentExpired r.decRefLocked() return nil @@ -398,7 +479,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.removeAddressLocked(addr) + return n.removePermanentAddressLocked(addr) } // joinGroup adds a new endpoint for the given multicast address, if none @@ -414,8 +495,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address if !ok { return tcpip.ErrUnknownProtocol } - addrWithPrefix := tcpip.AddressWithPrefix{addr, netProto.DefaultPrefixLen()} - if _, err := n.addAddressLocked(protocol, addrWithPrefix, NeverPrimaryEndpoint, false); err != nil { + if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: netProto.DefaultPrefixLen(), + }, + }, NeverPrimaryEndpoint); err != nil { return err } } @@ -437,7 +523,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress case 1: // This is the last one, clean up. - if err := n.removeAddressLocked(addr); err != nil { + if err := n.removePermanentAddressLocked(addr); err != nil { return err } } @@ -445,6 +531,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return nil } +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { + r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) + r.RemoteLinkAddress = remotelinkAddr + ref.ep.HandlePacket(&r, vv) + ref.decRef() +} + // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the physical interface. @@ -472,6 +565,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr src, dst := netProto.ParseAddresses(vv.First()) + n.stack.AddLinkAddress(n.id, src, remote) + // If the packet is destined to the IPv4 Broadcast address, then make a // route to each IPv4 network endpoint and let each endpoint handle the // packet. @@ -479,11 +574,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr // n.endpoints is mutex protected so acquire lock. n.mu.RLock() for _, ref := range n.endpoints { - if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() + if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) } } n.mu.RUnlock() @@ -491,10 +583,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr } if ref := n.getRef(protocol, dst); ref != nil { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) return } @@ -517,8 +606,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n := r.ref.nic n.mu.RLock() ref, ok := n.endpoints[NetworkEndpointID{dst}] + ok = ok && ref.isValidForOutgoing() && ref.tryIncRef() n.mu.RUnlock() - if ok && ref.tryIncRef() { + if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. ref.ep.HandlePacket(&r, vv) @@ -543,52 +633,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.stack.stats.IP.InvalidAddressesReceived.Increment() } -func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - id := NetworkEndpointID{dst} - - n.mu.RLock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.RUnlock() - return ref - } - - promiscuous := n.promiscuous - // Check if the packet is for a subnet this NIC cares about. - if !promiscuous { - for _, sn := range n.subnets { - if sn.Contains(dst) { - promiscuous = true - break - } - } - } - n.mu.RUnlock() - if promiscuous { - // Try again with the lock in exclusive mode. If we still can't - // get the endpoint, create a new "temporary" one. It will only - // exist while there's a route through it. - n.mu.Lock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.Unlock() - return ref - } - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - n.mu.Unlock() - return nil - } - addrWithPrefix := tcpip.AddressWithPrefix{dst, netProto.DefaultPrefixLen()} - ref, err := n.addAddressLocked(protocol, addrWithPrefix, CanBePrimaryEndpoint, true) - n.mu.Unlock() - if err == nil { - ref.holdsInsertRef = false - return ref - } - } - - return nil -} - // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { @@ -676,9 +720,33 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +type networkEndpointKind int + +const ( + // A permanent endpoint is created by adding a permanent address (vs. a + // temporary one) to the NIC. Its reference count is biased by 1 to avoid + // removal when no route holds a reference to it. It is removed by explicitly + // removing the permanent address from the NIC. + permanent networkEndpointKind = iota + + // An expired permanent endoint is a permanent endoint that had its address + // removed from the NIC, and it is waiting to be removed once no more routes + // hold a reference to it. This is achieved by decreasing its reference count + // by 1. If its address is re-added before the endpoint is removed, its type + // changes back to permanent and its reference count increases by 1 again. + permanentExpired + + // A temporary endpoint is created for spoofing outgoing packets, or when in + // promiscuous mode and accepting incoming packets that don't match any + // permanent endpoint. Its reference count is not biased by 1 and the + // endpoint is removed immediately when no more route holds a reference to + // it. A temporary endpoint can be promoted to permanent if its address + // is added permanently. + temporary +) + type referencedNetworkEndpoint struct { ilist.Entry - refs int32 ep NetworkEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber @@ -687,11 +755,25 @@ type referencedNetworkEndpoint struct { // protocol. Set to nil otherwise. linkCache LinkAddressCache - // holdsInsertRef is protected by the NIC's mutex. It indicates whether - // the reference count is biased by 1 due to the insertion of the - // endpoint. It is reset to false when RemoveAddress is called on the - // NIC. - holdsInsertRef bool + // refs is counting references held for this endpoint. When refs hits zero it + // triggers the automatic removal of the endpoint from the NIC. + refs int32 + + kind networkEndpointKind +} + +// isValidForOutgoing returns true if the endpoint can be used to send out a +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in spoofing mode. +func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { + return r.kind != permanentExpired || r.nic.spoofing +} + +// isValidForIncoming returns true if the endpoint can accept an incoming +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in promiscuous mode. +func (r *referencedNetworkEndpoint) isValidForIncoming() bool { + return r.kind != permanentExpired || r.nic.promiscuous } // decRef decrements the ref count and cleans up the endpoint once it reaches diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 391ab4344..e52cdd674 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -148,11 +148,15 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // IsResolutionRequired returns true if Resolve() must be called to resolve // the link address before the this route can be written to. func (r *Route) IsResolutionRequired() bool { - return r.ref.linkCache != nil && r.RemoteLinkAddress == "" + return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + if !r.ref.isValidForOutgoing() { + return tcpip.ErrInvalidEndpointState + } + err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() @@ -166,6 +170,10 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { + if !r.ref.isValidForOutgoing() { + return tcpip.ErrInvalidEndpointState + } + if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 57b8a9994..d69162ba1 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/waiter" @@ -333,6 +334,15 @@ type TCPEndpointState struct { Sender TCPSenderState } +// ResumableEndpoint is an endpoint that needs to be resumed after restore. +type ResumableEndpoint interface { + // Resume resumes an endpoint after restore. This can be used to restart + // background workers such as protocol goroutines. This must be called after + // all indirect dependencies of the endpoint has been restored, which + // generally implies at the end of the restore process. + Resume(*Stack) +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -372,6 +382,13 @@ type Stack struct { // handleLocal allows non-loopback interfaces to loop packets. handleLocal bool + + // tables are the iptables packet filtering and manipulation rules. + tables iptables.IPTables + + // resumableEndpoints is a list of endpoints that need to be resumed if the + // stack is being restored. + resumableEndpoints []ResumableEndpoint } // Options contains optional Stack configuration. @@ -751,10 +768,10 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) } -// AddAddressWithPrefix adds a new network-layer address/prefixLen to the +// AddProtocolAddress adds a new network-layer protocol address to the // specified NIC. -func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix) *tcpip.Error { - return s.AddAddressWithPrefixAndOptions(id, protocol, addrWithPrefix, CanBePrimaryEndpoint) +func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error { + return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint) } // AddAddressWithOptions is the same as AddAddress, but allows you to specify @@ -764,13 +781,18 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt if !ok { return tcpip.ErrUnknownProtocol } - addrWithPrefix := tcpip.AddressWithPrefix{addr, netProto.DefaultPrefixLen()} - return s.AddAddressWithPrefixAndOptions(id, protocol, addrWithPrefix, peb) -} - -// AddAddressWithPrefixAndOptions is the same as AddAddressWithPrefixLen, -// but allows you to specify whether the new endpoint can be primary or not. -func (s *Stack) AddAddressWithPrefixAndOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) *tcpip.Error { + return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: netProto.DefaultPrefixLen(), + }, + }, peb) +} + +// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows +// you to specify whether the new endpoint can be primary or not. +func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -779,7 +801,7 @@ func (s *Stack) AddAddressWithPrefixAndOptions(id tcpip.NICID, protocol tcpip.Ne return tcpip.ErrUnknownNICID } - return nic.AddAddress(protocol, addrWithPrefix, peb) + return nic.AddAddress(protocolAddress, peb) } // AddSubnet adds a subnet range to the specified NIC. @@ -873,7 +895,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } else { for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Match(remoteAddr)) { + if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) { continue } if nic, ok := s.nics[route.NIC]; ok { @@ -1082,6 +1104,28 @@ func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip } } +// RegisterRestoredEndpoint records e as an endpoint that has been restored on +// this stack. +func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) { + s.mu.Lock() + s.resumableEndpoints = append(s.resumableEndpoints, e) + s.mu.Unlock() +} + +// Resume restarts the stack after a restore. This must be called after the +// entire system has been restored. +func (s *Stack) Resume() { + // ResumableEndpoint.Resume() may call other methods on s, so we can't hold + // s.mu while resuming the endpoints. + s.mu.Lock() + eps := s.resumableEndpoints + s.resumableEndpoints = nil + s.mu.Unlock() + for _, e := range eps { + e.Resume(s) + } +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. @@ -1161,3 +1205,13 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC } return tcpip.ErrUnknownNICID } + +// IPTables returns the stack's iptables. +func (s *Stack) IPTables() iptables.IPTables { + return s.tables +} + +// SetIPTables sets the stack's iptables. +func (s *Stack) SetIPTables(ipt iptables.IPTables) { + s.tables = ipt +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9d082bba4..4debd1eec 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -181,6 +181,10 @@ func (f *fakeNetworkProtocol) DefaultPrefixLen() int { return fakeDefaultPrefixLen } +func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { + return f.packetCount[int(intfAddr)%len(f.packetCount)] +} + func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } @@ -188,7 +192,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ nicid: nicid, - id: stack.NetworkEndpointID{addrWithPrefix.Address}, + id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, @@ -289,16 +293,75 @@ func TestNetworkReceive(t *testing.T) { } } -func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) { +func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error { r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatal("FindRoute failed:", err) + return err } defer r.Release() + return send(r, payload) +} +func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil { - t.Error("WritePacket failed:", err) + return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) +} + +func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) { + t.Helper() + linkEP.Drain() + if err := sendTo(s, addr, payload); err != nil { + t.Error("sendTo failed:", err) + } + if got, want := linkEP.Drain(), 1; got != want { + t.Errorf("sendTo packet count: got = %d, want %d", got, want) + } +} + +func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) { + t.Helper() + linkEP.Drain() + if err := send(r, payload); err != nil { + t.Error("send failed:", err) + } + if got, want := linkEP.Drain(), 1; got != want { + t.Errorf("send packet count: got = %d, want %d", got, want) + } +} + +func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { + t.Helper() + if gotErr := send(r, payload); gotErr != wantErr { + t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) + } +} + +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { + t.Helper() + if gotErr := sendTo(s, addr, payload); gotErr != wantErr { + t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) + } +} + +func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { + t.Helper() + // testRecvInternal injects one packet, and we expect to receive it. + want := fakeNet.PacketCount(localAddrByte) + 1 + testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) +} + +func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { + t.Helper() + // testRecvInternal injects one packet, and we do NOT expect to receive it. + want := fakeNet.PacketCount(localAddrByte) + testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) +} + +func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) { + t.Helper() + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if got := fakeNet.PacketCount(localAddrByte); got != want { + t.Errorf("receive packet count: got = %d, want %d", got, want) } } @@ -312,17 +375,20 @@ func TestNetworkSend(t *testing.T) { t.Fatal("NewNIC failed:", err) } - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress failed:", err) } // Make sure that the link-layer endpoint received the outbound packet. - sendTo(t, s, "\x03", nil) - if c := linkEP.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x03", linkEP, nil) } func TestNetworkSendMultiRoute(t *testing.T) { @@ -360,24 +426,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { // Set a route table that sends all packets with odd destination // addresses through the first NIC, and all even destination address // through the second one. - s.SetRouteTable([]tcpip.Route{ - {"\x01", "\x01", "\x00", 1}, - {"\x00", "\x01", "\x00", 2}, - }) + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet0, Gateway: "\x00", NIC: 2}, + }) + } // Send a packet to an odd destination. - sendTo(t, s, "\x05", nil) - - if c := linkEP1.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x05", linkEP1, nil) // Send a packet to an even destination. - sendTo(t, s, "\x06", nil) - - if c := linkEP2.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x06", linkEP2, nil) } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { @@ -439,10 +507,20 @@ func TestRoutes(t *testing.T) { // Set a route table that sends all packets with odd destination // addresses through the first NIC, and all even destination address // through the second one. - s.SetRouteTable([]tcpip.Route{ - {"\x01", "\x01", "\x00", 1}, - {"\x00", "\x01", "\x00", 2}, - }) + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet0, Gateway: "\x00", NIC: 2}, + }) + } // Test routes to odd address. testRoute(t, s, 0, "", "\x05", "\x01") @@ -472,6 +550,10 @@ func TestRoutes(t *testing.T) { } func TestAddressRemoval(t *testing.T) { + const localAddrByte byte = 0x01 + localAddr := tcpip.Address([]byte{localAddrByte}) + remoteAddr := tcpip.Address("\x02") + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, linkEP := channel.New(10, defaultMTU, "") @@ -479,99 +561,285 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + // Send and receive packets, and verify they are received. + buf[0] = localAddrByte + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) - // Remove the address, then check that packet doesn't get delivered - // anymore. - if err := s.RemoveAddress(1, "\x01"); err != nil { + // Remove the address, then check that send/receive doesn't work anymore. + if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } } -func TestDelayedRemovalDueToRoute(t *testing.T) { +func TestAddressRemovalWithRouteHeld(t *testing.T) { + const localAddrByte byte = 0x01 + localAddr := tcpip.Address([]byte{localAddrByte}) + remoteAddr := tcpip.Address("\x02") + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { t.Fatal("CreateNIC failed:", err) } - - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) - } - - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - buf := buffer.NewView(30) - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - // Get a route, check that packet is still deliverable. - r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) + r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatal("FindRoute failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 2 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2) - } + // Send and receive packets, and verify they are received. + buf[0] = localAddrByte + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSend(t, r, linkEP, nil) + testSendTo(t, s, remoteAddr, linkEP, nil) - // Remove the address, then check that packet is still deliverable - // because the route is keeping the address alive. - if err := s.RemoveAddress(1, "\x01"); err != nil { + // Remove the address, then check that send/receive doesn't work anymore. + if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } +} - // Release the route, then check that packet is not deliverable anymore. - r.Release() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) +func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { + t.Helper() + info, ok := s.NICInfo()[nicid] + if !ok { + t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + } + if len(addr) == 0 { + // No address given, verify that there is no address assigned to the NIC. + for _, a := range info.ProtocolAddresses { + if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + } + } + return + } + // Address given, verify the address is assigned to the NIC and no other + // address is. + found := false + for _, a := range info.ProtocolAddresses { + if a.Protocol == fakeNetNumber { + if a.AddressWithPrefix.Address == addr { + found = true + } else { + t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr) + } + } + } + if !found { + t.Errorf("verify address: couldn't find %s on the NIC", addr) + } +} + +func TestEndpointExpiration(t *testing.T) { + const ( + localAddrByte byte = 0x01 + remoteAddr tcpip.Address = "\x03" + noAddr tcpip.Address = "" + nicid tcpip.NICID = 1 + ) + localAddr := tcpip.Address([]byte{localAddrByte}) + + for _, promiscuous := range []bool{true, false} { + for _, spoofing := range []bool{true, false} { + t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + buf := buffer.NewView(30) + buf[0] = localAddrByte + + if promiscuous { + if err := s.SetPromiscuousMode(nicid, true); err != nil { + t.Fatal("SetPromiscuousMode failed:", err) + } + } + + if spoofing { + if err := s.SetSpoofing(nicid, true); err != nil { + t.Fatal("SetSpoofing failed:", err) + } + } + + // 1. No Address yet, send should only work for spoofing, receive for + // promiscuous mode. + //----------------------- + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + + // 2. Add Address, everything should work. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + + // 3. Remove the address, send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + + // 4. Add Address back, everything should work again. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + + // 5. Take a reference to the endpoint by getting a route. Verify that + // we can still send/receive, including sending using the route. + //----------------------- + r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, linkEP, nil) + + // 6. Remove the address. Send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + testSend(t, r, linkEP, nil) + testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + + // 7. Add Address back, everything should work again. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, linkEP, nil) + + // 8. Remove the route, sendTo/recv should still work. + //----------------------- + r.Release() + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testSendTo(t, s, remoteAddr, linkEP, nil) + + // 9. Remove the address. Send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, linkEP, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) + } else { + testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + } + }) + } } } @@ -583,9 +851,13 @@ func TestPromiscuousMode(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) @@ -593,22 +865,15 @@ func TestPromiscuousMode(t *testing.T) { // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } + const localAddrByte byte = 0x01 + buf[0] = localAddrByte + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testRecv(t, fakeNet, localAddrByte, linkEP, buf) // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) @@ -621,54 +886,120 @@ func TestPromiscuousMode(t *testing.T) { if err := s.SetPromiscuousMode(1, false); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) +} - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) +func TestSpoofingWithAddress(t *testing.T) { + localAddr := tcpip.Address("\x01") + nonExistentLocalAddr := tcpip.Address("\x02") + dstAddr := tcpip.Address("\x03") + + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + // With address spoofing disabled, FindRoute does not permit an address + // that was not added to the NIC to be used as the source. + r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err == nil { + t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) + } + + // With address spoofing enabled, FindRoute permits any address to be used + // as the source. + if err := s.SetSpoofing(1, true); err != nil { + t.Fatal("SetSpoofing failed:", err) + } + r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + if r.LocalAddress != nonExistentLocalAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } + // Sending a packet works. + testSendTo(t, s, dstAddr, linkEP, nil) + testSend(t, r, linkEP, nil) + + // FindRoute should also work with a local address that exists on the NIC. + r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + if r.LocalAddress != localAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + } + // Sending a packet using the route works. + testSend(t, r, linkEP, nil) } -func TestAddressSpoofing(t *testing.T) { - srcAddr := tcpip.Address("\x01") +func TestSpoofingNoAddress(t *testing.T) { + nonExistentLocalAddr := tcpip.Address("\x01") dstAddr := tcpip.Address("\x02") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") + id, linkEP := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil { - t.Fatal("AddAddress failed:", err) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) - // With address spoofing disabled, FindRoute does not permit an address // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err == nil { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } + // Sending a packet fails. + testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute) // With address spoofing enabled, FindRoute permits any address to be used // as the source. if err := s.SetSpoofing(1, true); err != nil { t.Fatal("SetSpoofing failed:", err) } - r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatal("FindRoute failed:", err) } - if r.LocalAddress != srcAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr) + if r.LocalAddress != nonExistentLocalAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } + // Sending a packet works. + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, linkEP, nil) } func TestBroadcastNeedsNoRoute(t *testing.T) { @@ -806,16 +1137,20 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - buf[0] = 1 - fakeNet.packetCount[1] = 0 + const localAddrByte byte = 0x01 + buf[0] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -824,9 +1159,52 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { t.Fatal("AddSubnet failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + testRecv(t, fakeNet, localAddrByte, linkEP, buf) +} + +// Set the subnet, then check that CheckLocalAddress returns the correct NIC. +func TestCheckLocalAddressForSubnet(t *testing.T) { + const nicID tcpip.NICID = 1 + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}}) + } + + subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) + + if err != nil { + t.Fatal("NewSubnet failed:", err) + } + if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil { + t.Fatal("AddSubnet failed:", err) + } + + // Loop over all subnet addresses and check them. + numOfAddresses := 1 << uint(8-subnet.Prefix()) + if numOfAddresses < 1 || numOfAddresses > 255 { + t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) + } + addr := []byte(subnet.ID()) + for i := 0; i < numOfAddresses; i++ { + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID) + } + addr[0]++ + } + + // Trying the next address should fail since it is outside the subnet range. + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0) } } @@ -839,16 +1217,20 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - s.SetRouteTable([]tcpip.Route{ - {"\x00", "\x00", "\x00", 1}, - }) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - buf[0] = 1 - fakeNet.packetCount[1] = 0 + const localAddrByte byte = 0x01 + buf[0] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -856,10 +1238,7 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { t.Fatal("AddSubnet failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } + testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) } func TestNetworkOptions(t *testing.T) { @@ -969,12 +1348,18 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // prefixLen. address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) if behavior == stack.CanBePrimaryEndpoint { - addressWithPrefix := tcpip.AddressWithPrefix{address, addrLen * 8} - if err := s.AddAddressWithPrefixAndOptions(1, fakeNetNumber, addressWithPrefix, behavior); err != nil { - t.Fatal("AddAddressWithPrefixAndOptions failed:", err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: addrLen * 8, + }, + } + if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil { + t.Fatal("AddProtocolAddressWithOptions failed:", err) } // Remember the address/prefix. - primaryAddrAdded[addressWithPrefix] = struct{}{} + primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} } else { if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil { t.Fatal("AddAddressWithOptions failed:", err) @@ -1024,20 +1409,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116}, } { t.Run(tc.name, func(t *testing.T) { - addressWithPrefix := tcpip.AddressWithPrefix{tc.address, tc.prefixLen} - - if err := s.AddAddressWithPrefix(1, fakeNetNumber, addressWithPrefix); err != nil { - t.Fatal("AddAddressWithPrefix failed:", err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tc.address, + PrefixLen: tc.prefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddress); err != nil { + t.Fatal("AddProtocolAddress failed:", err) } // Check that we get the right initial address and prefix length. if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { t.Fatal("GetMainNICAddress failed:", err) - } else if gotAddressWithPrefix != addressWithPrefix { - t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, addressWithPrefix) + } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix { + t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix) } - if err := s.RemoveAddress(1, addressWithPrefix.Address); err != nil { + if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { t.Fatal("RemoveAddress failed:", err) } @@ -1102,7 +1492,7 @@ func TestAddAddress(t *testing.T) { verifyAddresses(t, expectedAddresses, gotAddresses) } -func TestAddAddressWithPrefix(t *testing.T) { +func TestAddProtocolAddress(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, _ := channel.New(10, defaultMTU, "") @@ -1116,14 +1506,17 @@ func TestAddAddressWithPrefix(t *testing.T) { expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) for _, addrLen := range addrLenRange { for _, prefixLen := range prefixLenRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithPrefix(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}); err != nil { - t.Errorf("AddAddressWithPrefix(address=%s, prefixLen=%d) failed: %s", address, prefixLen, err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addrGen.next(addrLen), + PrefixLen: prefixLen, + }, } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen}, - }) + if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil { + t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) + } + expectedAddresses = append(expectedAddresses, protocolAddress) } } @@ -1160,7 +1553,7 @@ func TestAddAddressWithOptions(t *testing.T) { verifyAddresses(t, expectedAddresses, gotAddresses) } -func TestAddAddressWithPrefixAndOptions(t *testing.T) { +func TestAddProtocolAddressWithOptions(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, _ := channel.New(10, defaultMTU, "") @@ -1176,14 +1569,17 @@ func TestAddAddressWithPrefixAndOptions(t *testing.T) { for _, addrLen := range addrLenRange { for _, prefixLen := range prefixLenRange { for _, behavior := range behaviorRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithPrefixAndOptions(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}, behavior); err != nil { - t.Fatalf("AddAddressWithPrefixAndOptions(address=%s, prefixLen=%d, behavior=%d) failed: %s", address, prefixLen, behavior, err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addrGen.next(addrLen), + PrefixLen: prefixLen, + }, } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen}, - }) + if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) + } + expectedAddresses = append(expectedAddresses, protocolAddress) } } } @@ -1196,15 +1592,19 @@ func TestNICStats(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id1, linkEP1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id1); err != nil { - t.Fatal("CreateNIC failed:", err) + t.Fatal("CreateNIC failed: ", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress failed:", err) } // Route all packets for address \x01 to NIC 1. - s.SetRouteTable([]tcpip.Route{ - {"\x01", "\xff", "\x00", 1}, - }) + { + subnet, err := tcpip.NewSubnet("\x01", "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } // Send a packet to address 1. buf := buffer.NewView(30) @@ -1219,7 +1619,9 @@ func TestNICStats(t *testing.T) { payload := buffer.NewView(10) // Write a packet out via the address for NIC 1 - sendTo(t, s, "\x01", payload) + if err := sendTo(s, "\x01", payload); err != nil { + t.Fatal("sendTo failed: ", err) + } want := uint64(linkEP1.Drain()) if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) @@ -1253,9 +1655,13 @@ func TestNICForwarding(t *testing.T) { } // Route all packets to address 3 to NIC 2. - s.SetRouteTable([]tcpip.Route{ - {"\x03", "\xff", "\x00", 2}, - }) + { + subnet, err := tcpip.NewSubnet("\x03", "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}}) + } // Send a packet to address 3. buf := buffer.NewView(30) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index b418db046..5335897f5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -64,7 +65,7 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { return 0, nil, tcpip.ErrNoRoute } @@ -78,10 +79,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) return 0, nil, err } - return uintptr(len(v)), nil, nil + return int64(len(v)), nil, nil } -func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -104,6 +105,11 @@ func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*fakeTransportEndpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { f.peerAddr = addr.Addr @@ -200,6 +206,13 @@ func (f *fakeTransportEndpoint) State() uint32 { func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) { } +func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { + return iptables.IPTables{}, nil +} + +func (f *fakeTransportEndpoint) Resume(*stack.Stack) { +} + type fakeTransportGoodOption bool type fakeTransportBadOption bool @@ -271,7 +284,13 @@ func TestTransportReceive(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress failed: %v", err) @@ -327,7 +346,13 @@ func TestTransportControlReceive(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress failed: %v", err) @@ -393,7 +418,13 @@ func TestTransportSend(t *testing.T) { t.Fatalf("AddAddress failed: %v", err) } - s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } // Create endpoint and bind it. wq := waiter.Queue{} @@ -484,10 +515,20 @@ func TestTransportForwarding(t *testing.T) { // Route all packets to address 3 to NIC 2 and all packets to address // 1 to NIC 1. - s.SetRouteTable([]tcpip.Route{ - {"\x03", "\xff", "\x00", 2}, - {"\x01", "\xff", "\x00", 1}, - }) + { + subnet0, err := tcpip.NewSubnet("\x03", "\xff") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet0, Gateway: "\x00", NIC: 2}, + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + }) + } wq := waiter.Queue{} ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 4208c0303..8f9b86cce 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -31,6 +31,7 @@ package tcpip import ( "errors" "fmt" + "math/bits" "reflect" "strconv" "strings" @@ -39,6 +40,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/waiter" ) @@ -144,8 +146,17 @@ type Address string type AddressMask string // String implements Stringer. -func (a AddressMask) String() string { - return Address(a).String() +func (m AddressMask) String() string { + return Address(m).String() +} + +// Prefix returns the number of bits before the first host bit. +func (m AddressMask) Prefix() int { + p := 0 + for _, b := range []byte(m) { + p += bits.LeadingZeros8(^b) + } + return p } // Subnet is a subnet defined by its address and mask. @@ -167,6 +178,11 @@ func NewSubnet(a Address, m AddressMask) (Subnet, error) { return Subnet{a, m}, nil } +// String implements Stringer. +func (s Subnet) String() string { + return fmt.Sprintf("%s/%d", s.ID(), s.Prefix()) +} + // Contains returns true iff the address is of the same length and matches the // subnet address and mask. func (s *Subnet) Contains(a Address) bool { @@ -189,28 +205,13 @@ func (s *Subnet) ID() Address { // Bits returns the number of ones (network bits) and zeros (host bits) in the // subnet mask. func (s *Subnet) Bits() (ones int, zeros int) { - for _, b := range []byte(s.mask) { - for i := uint(0); i < 8; i++ { - if b&(1<<i) == 0 { - zeros++ - } else { - ones++ - } - } - } - return + ones = s.mask.Prefix() + return ones, len(s.mask)*8 - ones } // Prefix returns the number of bits before the first host bit. func (s *Subnet) Prefix() int { - for i, b := range []byte(s.mask) { - for j := 7; j >= 0; j-- { - if b&(1<<uint(j)) == 0 { - return i*8 + 7 - j - } - } - } - return len(s.mask) * 8 + return s.mask.Prefix() } // Mask returns the subnet mask. @@ -328,12 +329,12 @@ type Endpoint interface { // ErrNoLinkAddress and a notification channel is returned for the caller to // block. Channel is closed once address resolution is complete (success or // not). The channel is only non-nil in this case. - Write(Payload, WriteOptions) (uintptr, <-chan struct{}, *Error) + Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error) // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. - Peek([][]byte) (uintptr, ControlMessages, *Error) + Peek([][]byte) (int64, ControlMessages, *Error) // Connect connects the endpoint to its peer. Specifying a NIC is // optional. @@ -352,6 +353,9 @@ type Endpoint interface { // ErrAddressFamilyNotSupported must be returned. Connect(address FullAddress) *Error + // Disconnect disconnects the endpoint from its peer. + Disconnect() *Error + // Shutdown closes the read and/or write end of the endpoint connection // to its peer. Shutdown(flags ShutdownFlags) *Error @@ -403,6 +407,9 @@ type Endpoint interface { // // NOTE: This method is a no-op for sockets other than TCP. ModerateRecvBuf(copied int) + + // IPTables returns the iptables for this endpoint's stack. + IPTables() (iptables.IPTables, error) } // WriteOptions contains options for Endpoint.Write. @@ -563,13 +570,8 @@ type BroadcastOption int // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. type Route struct { - // Destination is the address that must be matched against the masked - // target address to check if this row is viable. - Destination Address - - // Mask specifies which bits of the Destination and the target address - // must match for this row to be viable. - Mask AddressMask + // Destination must contain the target address for this row to be viable. + Destination Subnet // Gateway is the gateway to be used if this row is viable. Gateway Address @@ -578,25 +580,15 @@ type Route struct { NIC NICID } -// Match determines if r is viable for the given destination address. -func (r *Route) Match(addr Address) bool { - if len(addr) != len(r.Destination) { - return false - } - - // Using header.Ipv4Broadcast would introduce an import cycle, so - // we'll use a literal instead. - if addr == "\xff\xff\xff\xff" { - return true - } - - for i := 0; i < len(r.Destination); i++ { - if (addr[i] & r.Mask[i]) != r.Destination[i] { - return false - } +// String implements the fmt.Stringer interface. +func (r Route) String() string { + var out strings.Builder + fmt.Fprintf(&out, "%s", r.Destination) + if len(r.Gateway) > 0 { + fmt.Fprintf(&out, " via %s", r.Gateway) } - - return true + fmt.Fprintf(&out, " nic %d", r.NIC) + return out.String() } // LinkEndpointID represents a data link layer endpoint. @@ -1068,6 +1060,11 @@ type AddressWithPrefix struct { PrefixLen int } +// String implements the fmt.Stringer interface. +func (a AddressWithPrefix) String() string { + return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen) +} + // ProtocolAddress is an address and the network protocol it is associated // with. type ProtocolAddress struct { @@ -1078,11 +1075,13 @@ type ProtocolAddress struct { AddressWithPrefix AddressWithPrefix } -// danglingEndpointsMu protects access to danglingEndpoints. -var danglingEndpointsMu sync.Mutex +var ( + // danglingEndpointsMu protects access to danglingEndpoints. + danglingEndpointsMu sync.Mutex -// danglingEndpoints tracks all dangling endpoints no longer owned by the app. -var danglingEndpoints = make(map[Endpoint]struct{}) + // danglingEndpoints tracks all dangling endpoints no longer owned by the app. + danglingEndpoints = make(map[Endpoint]struct{}) +) // GetDanglingEndpoints returns all dangling endpoints. func GetDanglingEndpoints() []Endpoint { diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index ebb1c1b56..fb3a0a5ee 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -60,12 +60,12 @@ func TestSubnetBits(t *testing.T) { }{ {"\x00", 0, 8}, {"\x00\x00", 0, 16}, - {"\x36", 4, 4}, - {"\x5c", 4, 4}, - {"\x5c\x5c", 8, 8}, - {"\x5c\x36", 8, 8}, - {"\x36\x5c", 8, 8}, - {"\x36\x36", 8, 8}, + {"\x36", 0, 8}, + {"\x5c", 0, 8}, + {"\x5c\x5c", 0, 16}, + {"\x5c\x36", 0, 16}, + {"\x36\x5c", 0, 16}, + {"\x36\x36", 0, 16}, {"\xff", 8, 0}, {"\xff\xff", 16, 0}, } @@ -122,26 +122,6 @@ func TestSubnetCreation(t *testing.T) { } } -func TestRouteMatch(t *testing.T) { - tests := []struct { - d Address - m AddressMask - a Address - want bool - }{ - {"\xc2\x80", "\xff\xf0", "\xc2\x80", true}, - {"\xc2\x80", "\xff\xf0", "\xc2\x00", false}, - {"\xc2\x00", "\xff\xf0", "\xc2\x00", true}, - {"\xc2\x00", "\xff\xf0", "\xc2\x80", false}, - } - for _, tt := range tests { - r := Route{Destination: tt.d, Mask: tt.m} - if got := r.Match(tt.a); got != tt.want { - t.Errorf("Route(%v).Match(%v) = %v, want %v", r, tt.a, got, tt.want) - } - } -} - func TestAddressString(t *testing.T) { for _, want := range []string{ // Taken from stdlib. diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 62182a3e6..d78a162b8 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index ba6671c26..451d3880e 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -130,6 +131,11 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +// IPTables implements tcpip.Endpoint.IPTables. +func (e *endpoint) IPTables() (iptables.IPTables, error) { + return e.stack.IPTables(), nil +} + // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -199,7 +205,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -301,11 +307,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c return 0, nil, err } - return uintptr(len(v)), nil, nil + return int64(len(v)), nil, nil } // Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -422,16 +428,16 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t return netProto, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if addr.Addr == "" { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - nicid := addr.NIC localPort := uint16(0) switch e.state { diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 99b8c4093..c587b96b6 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -63,7 +63,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - e.stack = stack.StackFromEnv + stack.StackFromEnv.RegisterRestoredEndpoint(e) +} + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s if e.state != stateBound && e.state != stateConnected { return @@ -73,7 +78,7 @@ func (e *endpoint) afterLoad() { if e.state == stateConnected { e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) if err != nil { - panic(*err) + panic(err) } e.id.LocalAddress = e.route.LocalAddress @@ -85,6 +90,6 @@ func (e *endpoint) afterLoad() { e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) if err != nil { - panic(*err) + panic(err) } } diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index bc4b255b4..7241f6c19 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/waiter", ], diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index b633cd9d8..13e17e2a6 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -168,6 +169,11 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} +// IPTables implements tcpip.Endpoint.IPTables. +func (ep *endpoint) IPTables() (iptables.IPTables, error) { + return ep.stack.IPTables(), nil +} + // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !ep.associated { @@ -201,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes } // Write implements tcpip.Endpoint.Write. -func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -305,7 +311,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { // We may need to resolve the route (match a link layer address to the // network address). If that requires blocking (e.g. to use ARP), // return a channel on which the caller can wait. @@ -335,24 +341,24 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintpt return 0, nil, tcpip.ErrUnknownProtocol } - return uintptr(len(payloadBytes)), nil, nil + return int64(len(payloadBytes)), nil, nil } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect implements tcpip.Endpoint.Connect. func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if addr.Addr == "" { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - if ep.closed { return tcpip.ErrInvalidEndpointState } @@ -484,7 +490,7 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return nil + return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index cb5534d90..168953dec 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -63,19 +63,23 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { - // StackFromEnv is a stack used specifically for save/restore. - ep.stack = stack.StackFromEnv + stack.StackFromEnv.RegisterRestoredEndpoint(ep) +} + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (ep *endpoint) Resume(s *stack.Stack) { + ep.stack = s - // If the endpoint is connected, re-connect via the save/restore stack. + // If the endpoint is connected, re-connect. if ep.connected { var err *tcpip.Error ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) if err != nil { - panic(*err) + panic(err) } } - // If the endpoint is bound, re-bind via the save/restore stack. + // If the endpoint is bound, re-bind. if ep.bound { if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { panic(tcpip.ErrBadLocalAddress) @@ -83,6 +87,6 @@ func (ep *endpoint) afterLoad() { } if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { - panic(*err) + panic(err) } } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 4cd25e8e2..1ee1a53f8 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -48,6 +48,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 52fd1bfa3..e9c5099ea 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -96,6 +96,17 @@ type listenContext struct { hasher hash.Hash v6only bool netProto tcpip.NetworkProtocolNumber + // pendingMu protects pendingEndpoints. This should only be accessed + // by the listening endpoint's worker goroutine. + // + // Lock Ordering: listenEP.workerMu -> pendingMu + pendingMu sync.Mutex + // pending is used to wait for all pendingEndpoints to finish when + // a socket is closed. + pending sync.WaitGroup + // pendingEndpoints is a map of all endpoints for which a handshake is + // in progress. + pendingEndpoints map[stack.TransportEndpointID]*endpoint } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. @@ -133,14 +144,15 @@ func decSynRcvdCount() { } // newListenContext creates a new listen context. -func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { +func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ - stack: stack, - rcvWnd: rcvWnd, - hasher: sha1.New(), - v6only: v6only, - netProto: netProto, - listenEP: listenEP, + stack: stk, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6only: v6only, + netProto: netProto, + listenEP: listenEP, + pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } rand.Read(l.nonce[0][:]) @@ -253,6 +265,17 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } + // listenEP is nil when listenContext is used by tcp.Forwarder. + if l.listenEP != nil { + l.listenEP.mu.Lock() + if l.listenEP.state != StateListen { + l.listenEP.mu.Unlock() + return nil, tcpip.ErrConnectionAborted + } + l.addPendingEndpoint(ep) + l.listenEP.mu.Unlock() + } + // Perform the 3-way handshake. h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) @@ -260,6 +283,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if err := h.execute(); err != nil { ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.Close() + if l.listenEP != nil { + l.removePendingEndpoint(ep) + } return nil, err } ep.mu.Lock() @@ -274,15 +300,41 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return ep, nil } +func (l *listenContext) addPendingEndpoint(n *endpoint) { + l.pendingMu.Lock() + l.pendingEndpoints[n.id] = n + l.pending.Add(1) + l.pendingMu.Unlock() +} + +func (l *listenContext) removePendingEndpoint(n *endpoint) { + l.pendingMu.Lock() + delete(l.pendingEndpoints, n.id) + l.pending.Done() + l.pendingMu.Unlock() +} + +func (l *listenContext) closeAllPendingEndpoints() { + l.pendingMu.Lock() + for _, n := range l.pendingEndpoints { + n.notifyProtocolGoroutine(notifyClose) + } + l.pendingMu.Unlock() + l.pending.Wait() +} + // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state, the new endpoint is closed // instead. func (e *endpoint) deliverAccepted(n *endpoint) { - e.mu.RLock() + e.mu.Lock() state := e.state - e.mu.RUnlock() + e.pendingAccepted.Add(1) + defer e.pendingAccepted.Done() + acceptedChan := e.acceptedChan + e.mu.Unlock() if state == StateListen { - e.acceptedChan <- n + acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) } else { n.Close() @@ -304,7 +356,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header e.stack.Stats().TCP.FailedConnectionAttempts.Increment() return } - + ctx.removePendingEndpoint(n) e.deliverAccepted(n) } @@ -451,6 +503,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // protocolListenLoop is the main loop of a listening TCP endpoint. It runs in // its own goroutine and is responsible for handling connection requests. func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { + e.mu.Lock() + v6only := e.v6only + e.mu.Unlock() + ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) + defer func() { // Mark endpoint as closed. This will prevent goroutines running // handleSynSegment() from attempting to queue new connections @@ -458,6 +515,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() e.state = StateClose + // close any endpoints in SYN-RCVD state. + ctx.closeAllPendingEndpoints() + // Do cleanup if needed. e.completeWorkerLocked() @@ -470,12 +530,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) }() - e.mu.Lock() - v6only := e.v6only - e.mu.Unlock() - - ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) - s := sleep.Sleeper{} s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) @@ -492,7 +546,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.handleListenSegment(ctx, s) s.decRef() } - synRcvdCount.pending.Wait() close(e.drainDone) <-e.undrain } diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index d9f79e8c5..c54610a87 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -570,3 +570,89 @@ func TestV4AcceptOnV4(t *testing.T) { // Test acceptance. testV4Accept(t, c) } + +func testV4ListenClose(t *testing.T, c *context.Context) { + // Set the SynRcvd threshold to zero to force a syn cookie based accept + // to happen. + saved := tcp.SynRcvdCountThreshold + defer func() { + tcp.SynRcvdCountThreshold = saved + }() + tcp.SynRcvdCountThreshold = 0 + const n = uint16(32) + + // Start listening. + if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + irs := seqnum.Value(789) + for i := uint16(0); i < n; i++ { + // Send a SYN request. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort + i, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + } + + // Each of these ACK's will cause a syn-cookie based connection to be + // accepted and delivered to the listening endpoint. + for i := uint16(0); i < n; i++ { + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + iss := seqnum.Value(tcp.SequenceNumber()) + // Send ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: tcp.DestinationPort(), + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + } + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + nep, _, err := c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + nep, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(10 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + nep.Close() + c.EP.Close() +} + +func TestV4ListenCloseOnV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4ListenClose(t, c) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cc49c8272..ac927569a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tmutex" @@ -361,6 +362,12 @@ type endpoint struct { // without hearing a response, the connection is closed. keepalive keepalive + // pendingAccepted is a synchronization primitive used to track number + // of connections that are queued up to be delivered to the accepted + // channel. We use this to ensure that all goroutines blocked on writing + // to the acceptedChan below terminate before we close acceptedChan. + pendingAccepted sync.WaitGroup `state:"nosave"` + // acceptedChan is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. @@ -374,7 +381,11 @@ type endpoint struct { // The goroutine drain completion notification channel. drainDone chan struct{} `state:"nosave"` - // The goroutine undrain notification channel. + // The goroutine undrain notification channel. This is currently used as + // a way to block the worker goroutines. Today nothing closes/writes + // this channel and this causes any goroutines waiting on this to just + // block. This is used during save/restore to prevent worker goroutines + // from mutating state as it's being saved. undrain chan struct{} `state:"nosave"` // probe if not nil is invoked on every received segment. It is passed @@ -574,6 +585,34 @@ func (e *endpoint) Close() { e.mu.Unlock() } +// closePendingAcceptableConnections closes all connections that have completed +// handshake but not yet been delivered to the application. +func (e *endpoint) closePendingAcceptableConnectionsLocked() { + done := make(chan struct{}) + // Spin a goroutine up as ranging on e.acceptedChan will just block when + // there are no more connections in the channel. Using a non-blocking + // select does not work as it can potentially select the default case + // even when there are pending writes but that are not yet written to + // the channel. + go func() { + defer close(done) + for n := range e.acceptedChan { + n.mu.Lock() + n.resetConnectionLocked(tcpip.ErrConnectionAborted) + n.mu.Unlock() + n.Close() + } + }() + // pendingAccepted(see endpoint.deliverAccepted) tracks the number of + // endpoints which have completed handshake but are not yet written to + // the e.acceptedChan. We wait here till the goroutine above can drain + // all such connections from e.acceptedChan. + e.pendingAccepted.Wait() + close(e.acceptedChan) + <-done + e.acceptedChan = nil +} + // cleanupLocked frees all resources associated with the endpoint. It is called // after Close() is called and the worker goroutine (if any) is done with its // work. @@ -581,14 +620,7 @@ func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { - close(e.acceptedChan) - for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() - n.Close() - } - e.acceptedChan = nil + e.closePendingAcceptableConnectionsLocked() } e.workerCleanup = false @@ -683,6 +715,11 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } +// IPTables implements tcpip.Endpoint.IPTables. +func (e *endpoint) IPTables() (iptables.IPTables, error) { + return e.stack.IPTables(), nil +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() @@ -740,60 +777,95 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { return v, nil } -// Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { - // Linux completely ignores any address passed to sendto(2) for TCP sockets - // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More - // and opts.EndOfRecord are also ignored. - - e.mu.RLock() - defer e.mu.RUnlock() - +// isEndpointWritableLocked checks if a given endpoint is writable +// and also returns the number of bytes that can be written at this +// moment. If the endpoint is not writable then it returns an error +// indicating the reason why it's not writable. +// Caller must hold e.mu and e.sndBufMu +func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { // The endpoint cannot be written to if it's not connected. if !e.state.connected() { switch e.state { case StateError: - return 0, nil, e.hardError + return 0, e.hardError default: - return 0, nil, tcpip.ErrClosedForSend + return 0, tcpip.ErrClosedForSend } } - // Nothing to do if the buffer is empty. - if p.Size() == 0 { - return 0, nil, nil - } - - e.sndBufMu.Lock() - // Check if the connection has already been closed for sends. if e.sndClosed { - e.sndBufMu.Unlock() - return 0, nil, tcpip.ErrClosedForSend + return 0, tcpip.ErrClosedForSend } - // Check against the limit. avail := e.sndBufSize - e.sndBufUsed if avail <= 0 { + return 0, tcpip.ErrWouldBlock + } + return avail, nil +} + +// Write writes data to the endpoint's peer. +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // Linux completely ignores any address passed to sendto(2) for TCP sockets + // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More + // and opts.EndOfRecord are also ignored. + + e.mu.RLock() + e.sndBufMu.Lock() + + avail, err := e.isEndpointWritableLocked() + if err != nil { e.sndBufMu.Unlock() - return 0, nil, tcpip.ErrWouldBlock + e.mu.RUnlock() + return 0, nil, err } + e.sndBufMu.Unlock() + e.mu.RUnlock() + + // Nothing to do if the buffer is empty. + if p.Size() == 0 { + return 0, nil, nil + } + + // Copy in memory without holding sndBufMu so that worker goroutine can + // make progress independent of this operation. v, perr := p.Get(avail) if perr != nil { - e.sndBufMu.Unlock() return 0, nil, perr } - l := len(v) - s := newSegmentFromView(&e.route, e.id, v) + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a + // write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due to a + // simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } // Add data to the send queue. + l := len(v) + s := newSegmentFromView(&e.route, e.id, v) e.sndBufUsed += l e.sndBufInQueue += seqnum.Size(l) e.sndQueue.PushBack(s) e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() if e.workMu.TryLock() { // Do the work inline. @@ -803,13 +875,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c // Let the protocol goroutine do the work. e.sndWaker.Assert() } - return uintptr(l), nil, nil + return int64(l), nil, nil } // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. -func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() @@ -835,8 +907,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er // Make a copy of vec so we can modify the slide headers. vec = append([][]byte(nil), vec...) - var num uintptr - + var num int64 for s := e.rcvList.Front(); s != nil; s = s.Next() { views := s.data.Views() @@ -855,7 +926,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er n := copy(vec[0], v) v = v[n:] vec[0] = vec[0][n:] - num += uintptr(n) + num += int64(n) } } } @@ -1277,7 +1348,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol netProto = header.IPv4ProtocolNumber addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == "\x00\x00\x00\x00" { + if addr.Addr == header.IPv4Any { addr.Addr = "" } } @@ -1291,13 +1362,13 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol return netProto, nil } +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - if addr.Addr == "" && addr.Port == 0 { - // AF_UNSPEC isn't supported. - return tcpip.ErrAddressFamilyNotSupported - } - return e.connect(addr, true, true) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index b3f0f6c5d..831389ec7 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -165,7 +165,12 @@ func (e *endpoint) loadState(state EndpointState) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - e.stack = stack.StackFromEnv + stack.StackFromEnv.RegisterRestoredEndpoint(e) +} + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() @@ -197,14 +202,13 @@ func (e *endpoint) afterLoad() { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: bind() if len(e.connectingAddress) == 0 { + e.connectingAddress = e.id.RemoteAddress // This endpoint is accepted by netstack but not yet by // the app. If the endpoint is IPv6 but the remote // address is IPv4, we need to connect as IPv6 so that // dual-stack mode can be properly activated. if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress - } else { - e.connectingAddress = e.id.RemoteAddress } } // Reset the scoreboard to reinitialize the sack information as diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 0fee7ab72..735edfe55 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -39,6 +39,28 @@ const ( nDupAckThreshold = 3 ) +// ccState indicates the current congestion control state for this sender. +type ccState int + +const ( + // Open indicates that the sender is receiving acks in order and + // no loss or dupACK's etc have been detected. + Open ccState = iota + // RTORecovery indicates that an RTO has occurred and the sender + // has entered an RTO based recovery phase. + RTORecovery + // FastRecovery indicates that the sender has entered FastRecovery + // based on receiving nDupAck's. This state is entered only when + // SACK is not in use. + FastRecovery + // SACKRecovery indicates that the sender has entered SACK based + // recovery. + SACKRecovery + // Disorder indicates the sender either received some SACK blocks + // or dupACK's. + Disorder +) + // congestionControl is an interface that must be implemented by any supported // congestion control algorithm. type congestionControl interface { @@ -138,6 +160,9 @@ type sender struct { // maxSentAck is the maxium acknowledgement actually sent. maxSentAck seqnum.Value + // state is the current state of congestion control for this endpoint. + state ccState + // cc is the congestion control algorithm in use for this sender. cc congestionControl } @@ -435,6 +460,7 @@ func (s *sender) retransmitTimerExpired() bool { s.leaveFastRecovery() } + s.state = RTORecovery s.cc.HandleRTOExpired() // Mark the next segment to be sent as the first unacknowledged one and @@ -638,7 +664,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se segEnd = seg.sequenceNumber.Add(1) // Transition to FIN-WAIT1 state since we're initiating an active close. s.ep.mu.Lock() - s.ep.state = StateFinWait1 + switch s.ep.state { + case StateCloseWait: + // We've already received a FIN and are now sending our own. The + // sender is now awaiting a final ACK for this FIN. + s.ep.state = StateLastAck + default: + s.ep.state = StateFinWait1 + } s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. @@ -820,9 +853,11 @@ func (s *sender) enterFastRecovery() { s.fr.last = s.sndNxt - 1 s.fr.maxCwnd = s.sndCwnd + s.outstanding if s.ep.sackPermitted { + s.state = SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() return } + s.state = FastRecovery s.ep.stack.Stats().TCP.FastRecovery.Increment() } @@ -981,6 +1016,7 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { s.fr.highRxt = s.sndUna - 1 // Do run SetPipe() to calculate the outstanding segments. s.SetPipe() + s.state = Disorder return false } @@ -1112,6 +1148,9 @@ func (s *sender) handleRcvdSegment(seg *segment) { // window based on the number of acknowledged packets. if !s.fr.active { s.cc.Update(originalOutstanding - s.outstanding) + if s.fr.last.LessThan(s.sndUna) { + s.state = Open + } } // It is possible for s.outstanding to drop below zero if we get diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 915a98047..f79b8ec5f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2874,15 +2874,11 @@ func makeStack() (*stack.Stack, *tcpip.Error) { s.SetRouteTable([]tcpip.Route{ { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }, { - Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv6EmptySubnet, NIC: 1, }, }) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index bcc0f3e28..272481aa0 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -168,15 +168,11 @@ func New(t *testing.T, mtu uint32) *Context { s.SetRouteTable([]tcpip.Route{ { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }, { - Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv6EmptySubnet, NIC: 1, }, }) diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 6dac66b50..ac2666f69 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/waiter", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 91f89a781..ac5905772 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -172,6 +173,11 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +// IPTables implements tcpip.Endpoint.IPTables. +func (e *endpoint) IPTables() (iptables.IPTables, error) { + return e.stack.IPTables(), nil +} + // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -241,13 +247,13 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stack.Route, tcpip.NICID, tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto, err := e.checkV4Mapped(&addr, false) - if err != nil { - return stack.Route{}, 0, 0, err +func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { + localAddr := e.id.LocalAddress + if isBroadcastOrMulticast(localAddr) { + // A packet can only originate from a unicast address (i.e., an interface). + localAddr = "" } - localAddr := e.id.LocalAddress if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { if nicid == 0 { nicid = e.multicastNICID @@ -260,14 +266,14 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stac // Find a route to the desired destination. r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop) if err != nil { - return stack.Route{}, 0, 0, err + return stack.Route{}, 0, err } - return r, nicid, netProto, nil + return r, nicid, nil } // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -336,7 +342,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c return 0, nil, tcpip.ErrBroadcastDisabled } - r, _, _, err := e.connectRoute(nicid, *to) + netProto, err := e.checkV4Mapped(to, false) + if err != nil { + return 0, nil, err + } + + r, _, err := e.connectRoute(nicid, *to, netProto) if err != nil { return 0, nil, err } @@ -368,11 +379,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil { return 0, nil, err } - return uintptr(len(v)), nil, nil + return int64(len(v)), nil, nil } // Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -442,7 +453,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } nicID := v.NIC - if v.InterfaceAddr == header.IPv4Any { + + // The interface address is considered not-set if it is empty or contains + // all-zeros. The former represent the zero-value in golang, the latter the + // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall. + allZeros := header.IPv4Any + if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros { if nicID == 0 { r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) if err == nil { @@ -686,7 +702,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t netProto = header.IPv4ProtocolNumber addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == "\x00\x00\x00\x00" { + if addr.Addr == header.IPv4Any { addr.Addr = "" } @@ -705,7 +721,8 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t return netProto, nil } -func (e *endpoint) disconnect() *tcpip.Error { +// Disconnect implements tcpip.Endpoint.Disconnect. +func (e *endpoint) Disconnect() *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -740,8 +757,9 @@ func (e *endpoint) disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - if addr.Addr == "" { - return e.disconnect() + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return err } if addr.Port == 0 { // We don't support connecting to port zero. @@ -770,7 +788,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - r, nicid, netProto, err := e.connectRoute(nicid, addr) + r, nicid, err := e.connectRoute(nicid, addr, netProto) if err != nil { return err } @@ -906,8 +924,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } nicid := addr.NIC - if len(addr.Addr) != 0 { - // A local address was specified, verify that it's valid. + if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { + // A local unicast address was specified, verify that it's valid. nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicid == 0 { return tcpip.ErrBadLocalAddress @@ -1056,3 +1074,7 @@ func (e *endpoint) State() uint32 { // TODO(b/112063468): Translate internal state to values returned by Linux. return 0 } + +func isBroadcastOrMulticast(a tcpip.Address) bool { + return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 18e786397..5cbb56120 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -64,7 +64,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - e.stack = stack.StackFromEnv + stack.StackFromEnv.RegisterRestoredEndpoint(e) +} + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s for _, m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { @@ -90,9 +95,10 @@ func (e *endpoint) afterLoad() { if e.state == stateConnected { e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) if err != nil { - panic(*err) + panic(err) } - } else if len(e.id.LocalAddress) != 0 { // stateBound + } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound + // A local unicast address is specified, verify that it's valid. if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } @@ -105,6 +111,6 @@ func (e *endpoint) afterLoad() { e.id.LocalPort = 0 e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) if err != nil { - panic(*err) + panic(err) } } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 56c285f88..9da6edce2 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,6 +16,7 @@ package udp_test import ( "bytes" + "fmt" "math" "math/rand" "testing" @@ -34,13 +35,19 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// Addresses and ports used for testing. It is recommended that tests stick to +// using these addresses as it allows using the testFlow helper. +// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*' +// represents the remote endpoint. const ( + v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr - testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr - multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + stackV4MappedAddr = v4MappedAddrPrefix + stackAddr + testV4MappedAddr = v4MappedAddrPrefix + testAddr + multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr + broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr + v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00" stackAddr = "\x0a\x00\x00\x01" stackPort = 1234 @@ -48,7 +55,7 @@ const ( testPort = 4096 multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - multicastPort = 1234 + broadcastAddr = header.IPv4Broadcast // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -56,6 +63,205 @@ const ( defaultMTU = 65536 ) +// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in +// a packet header. These values are used to populate a header or verify one. +// Note that because they are used in packet headers, the addresses are never in +// a V4-mapped format. +type header4Tuple struct { + srcAddr tcpip.FullAddress + dstAddr tcpip.FullAddress +} + +// testFlow implements a helper type used for sending and receiving test +// packets. A given test flow value defines 1) the socket endpoint used for the +// test and 2) the type of packet send or received on the endpoint. E.g., a +// multicastV6Only flow is a V6 multicast packet passing through a V6-only +// endpoint. The type provides helper methods to characterize the flow (e.g., +// isV4) as well as return a proper header4Tuple for it. +type testFlow int + +const ( + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket +) + +func (flow testFlow) String() string { + switch flow { + case unicastV4: + return "unicastV4" + case unicastV6: + return "unicastV6" + case unicastV6Only: + return "unicastV6Only" + case unicastV4in6: + return "unicastV4in6" + case multicastV4: + return "multicastV4" + case multicastV6: + return "multicastV6" + case multicastV6Only: + return "multicastV6Only" + case multicastV4in6: + return "multicastV4in6" + case broadcast: + return "broadcast" + case broadcastIn6: + return "broadcastIn6" + default: + return "unknown" + } +} + +// packetDirection explains if a flow is incoming (read) or outgoing (write). +type packetDirection int + +const ( + incoming packetDirection = iota + outgoing +) + +// header4Tuple returns the header4Tuple for the given flow and direction. Note +// that the tuple contains no mapped addresses as those only exist at the socket +// level but not at the packet header level. +func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { + var h header4Tuple + if flow.isV4() { + if d == outgoing { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, + dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, + } + } else { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, + dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, + } + } + if flow.isMulticast() { + h.dstAddr.Addr = multicastAddr + } else if flow.isBroadcast() { + h.dstAddr.Addr = broadcastAddr + } + } else { // IPv6 + if d == outgoing { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, + dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + } + } else { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, + } + } + if flow.isMulticast() { + h.dstAddr.Addr = multicastV6Addr + } + } + return h +} + +func (flow testFlow) getMcastAddr() tcpip.Address { + if flow.isV4() { + return multicastAddr + } + return multicastV6Addr +} + +// mapAddrIfApplicable converts the given V4 address into its V4-mapped version +// if it is applicable to the flow. +func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { + if flow.isMapped() { + return v4MappedAddrPrefix + v4Addr + } + return v4Addr +} + +// netProto returns the protocol number used for the network packet. +func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { + if flow.isV4() { + return ipv4.ProtocolNumber + } + return ipv6.ProtocolNumber +} + +// sockProto returns the protocol number used when creating the socket +// endpoint for this flow. +func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { + switch flow { + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + return ipv6.ProtocolNumber + case unicastV4, multicastV4, broadcast: + return ipv4.ProtocolNumber + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { + if flow.isV4() { + return checker.IPv4 + } + return checker.IPv6 +} + +func (flow testFlow) isV6() bool { return !flow.isV4() } +func (flow testFlow) isV4() bool { + return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() +} + +func (flow testFlow) isV6Only() bool { + switch flow { + case unicastV6Only, multicastV6Only: + return true + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isMulticast() bool { + switch flow { + case multicastV4, multicastV4in6, multicastV6, multicastV6Only: + return true + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isBroadcast() bool { + switch flow { + case broadcast, broadcastIn6: + return true + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isMapped() bool { + switch flow { + case unicastV4in6, multicastV4in6, broadcastIn6: + return true + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -65,12 +271,9 @@ type testContext struct { wq waiter.Queue } -type headers struct { - srcPort uint16 - dstPort uint16 -} - func newDualTestContext(t *testing.T, mtu uint32) *testContext { + t.Helper() + s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) id, linkEP := channel.New(256, mtu, "") @@ -91,15 +294,11 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { s.SetRouteTable([]tcpip.Route{ { - Destination: "\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv4EmptySubnet, NIC: 1, }, { - Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - Gateway: "", + Destination: header.IPv6EmptySubnet, NIC: 1, }, }) @@ -117,51 +316,54 @@ func (c *testContext) cleanup() { } } -func (c *testContext) createV6Endpoint(v6only bool) { +func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { + c.t.Helper() + var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) + c.t.Fatal("NewEndpoint failed: ", err) } +} - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.ep.SetSockOpt(v); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) +func (c *testContext) createEndpointForFlow(flow testFlow) { + c.t.Helper() + + c.createEndpoint(flow.sockProto()) + if flow.isV6Only() { + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + } else if flow.isBroadcast() { + if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { + c.t.Fatal("SetSockOpt failed:", err) + } } } -func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte { +// getPacketAndVerify reads a packet from the link endpoint and verifies the +// header against expected values from the given test flow. In addition, it +// calls any extra checker functions provided. +func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { + c.t.Helper() + select { case p := <-c.linkEP.C: - if p.Proto != protocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber) + if p.Proto != flow.netProto() { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } b := make([]byte, len(p.Header)+len(p.Payload)) copy(b, p.Header) copy(b[len(p.Header):], p.Payload) - var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker) - var srcAddr, dstAddr tcpip.Address - switch protocolNumber { - case ipv4.ProtocolNumber: - checkerFn = checker.IPv4 - srcAddr, dstAddr = stackAddr, testAddr - if multicast { - dstAddr = multicastAddr - } - case ipv6.ProtocolNumber: - checkerFn = checker.IPv6 - srcAddr, dstAddr = stackV6Addr, testV6Addr - if multicast { - dstAddr = multicastV6Addr - } - default: - c.t.Fatalf("unknown protocol %d", protocolNumber) - } - checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr)) + h := flow.header4Tuple(outgoing) + checkers := append( + checkers, + checker.SrcAddr(h.srcAddr.Addr), + checker.DstAddr(h.dstAddr.Addr), + checker.UDP(checker.DstPort(h.dstAddr.Port)), + ) + flow.checkerFn()(c.t, b, checkers...) return b case <-time.After(2 * time.Second): @@ -171,7 +373,22 @@ func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, mult return nil } -func (c *testContext) sendV6Packet(payload []byte, h *headers) { +// injectPacket creates a packet of the given flow and with the given payload, +// and injects it into the link endpoint. +func (c *testContext) injectPacket(flow testFlow, payload []byte) { + c.t.Helper() + + h := flow.header4Tuple(incoming) + if flow.isV4() { + c.injectV4Packet(payload, &h) + } else { + c.injectV6Packet(payload, &h) + } +} + +// injectV6Packet creates a V6 test packet with the given payload and header +// values, and injects it into the link endpoint. +func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -182,20 +399,20 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) { PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -205,7 +422,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) { c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) } -func (c *testContext) sendPacket(payload []byte, h *headers) { +// injectV6Packet creates a V4 test packet with the given payload and header +// values, and injects it into the link endpoint. +func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -217,21 +436,21 @@ func (c *testContext) sendPacket(payload []byte, h *headers) { TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), - SrcAddr: testAddr, - DstAddr: stackAddr, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) ip.SetChecksum(^ip.CalculateChecksum()) // Initialize the UDP header. u := header.UDP(buf[header.IPv4MinimumSize:]) u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testAddr, stackAddr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -253,7 +472,7 @@ func TestBindPortReuse(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) var eps [5]tcpip.Endpoint reusePortOpt := tcpip.ReusePortOption(1) @@ -296,9 +515,9 @@ func TestBindPortReuse(t *testing.T) { // Send a packet. port := uint16(i % nports) payload := newPayload() - c.sendV6Packet(payload, &headers{ - srcPort: testPort + port, - dstPort: stackPort, + c.injectV6Packet(payload, &header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port}, + dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, }) var addr tcpip.FullAddress @@ -333,13 +552,14 @@ func TestBindPortReuse(t *testing.T) { } } -func testV4Read(c *testContext) { - // Send a packet. +// testRead sends a packet of the given test flow into the stack by injecting it +// into the link endpoint. It then reads it from the UDP endpoint and verifies +// its correctness. +func testRead(c *testContext, flow testFlow) { + c.t.Helper() + payload := newPayload() - c.sendPacket(payload, &headers{ - srcPort: testPort, - dstPort: stackPort, - }) + c.injectPacket(flow, payload) // Try to receive the data. we, ch := waiter.NewChannelEntry(nil) @@ -363,8 +583,9 @@ func testV4Read(c *testContext) { } // Check the peer address. - if addr.Addr != testAddr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) + h := flow.header4Tuple(incoming) + if addr.Addr != h.srcAddr.Addr { + c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr) } // Check the payload. @@ -377,7 +598,7 @@ func TestBindEphemeralPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { t.Fatalf("ep.Bind(...) failed: %v", err) @@ -388,7 +609,7 @@ func TestBindReservedPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) @@ -447,7 +668,7 @@ func TestV4ReadOnV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -455,29 +676,29 @@ func TestV4ReadOnV6(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to v4 mapped wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}); err != nil { + if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV4ReadOnBoundToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to local address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { @@ -485,69 +706,29 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV6ReadOnV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV6) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - // Send a packet. - payload := newPayload() - c.sendV6Packet(payload, &headers{ - srcPort: testPort, - dstPort: stackPort, - }) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.EventIn) - defer c.wq.EventUnregister(&we) - - var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) - if err == tcpip.ErrWouldBlock { - // Wait for data to become available. - select { - case <-ch: - v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") - } - } - - // Check the peer address. - if addr.Addr != testV6Addr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) - } - - // Check the payload. - if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) - } + // Test acceptance. + testRead(c, unicastV6) } func TestV4ReadOnV4(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - // Create v4 UDP endpoint. - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } + c.createEndpointForFlow(unicastV4) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -555,62 +736,123 @@ func TestV4ReadOnV4(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4) } -func testV4Write(c *testContext) uint16 { - // Write to V4 mapped address. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != nil { - c.t.Fatalf("Write failed: %v", err) +// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast +// address and receive data sent to that address. +func TestReadOnBoundToMulticast(t *testing.T) { + // FIXME(b/128189410): multicastV4in6 currently doesn't work as + // AddMembershipOption doesn't handle V4in6 addresses. + for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to multicast address. + mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr()) + if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + // Join multicast group. + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatal("SetSockOpt failed:", err) + } + + testRead(c, flow) + }) } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) +} + +// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast +// address and receive broadcast data on it. +func TestV4ReadOnBoundToBroadcast(t *testing.T) { + for _, flow := range []testFlow{broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to broadcast address. + bcastAddr := flow.mapAddrIfApplicable(broadcastAddr) + if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + // Test acceptance. + testRead(c, flow) + }) } +} - // Check that we received the packet. - b := c.getPacket(ipv4.ProtocolNumber, false) - udp := header.UDP(header.IPv4(b).Payload()) - checker.IPv4(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) +// testFailingWrite sends a packet of the given test flow into the UDP endpoint +// and verifies it fails with the provided error code. +func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { + c.t.Helper() - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + h := flow.header4Tuple(outgoing) + writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) + + payload := buffer.View(newPayload()) + _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, + }) + if gotErr != wantErr { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) } +} - return udp.SourcePort() +// testWrite sends a packet of the given test flow from the UDP endpoint to the +// flow's destination address:port. It then receives it from the link endpoint +// and verifies its correctness including any additional checker functions +// provided. +func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + return testWriteInternal(c, flow, true, checkers...) } -func testV6Write(c *testContext) uint16 { - // Write to v6 address. +// testWriteWithoutDestination sends a packet of the given test flow from the +// UDP endpoint without giving a destination address:port. It then receives it +// from the link endpoint and verifies its correctness including any additional +// checker functions provided. +func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + return testWriteInternal(c, flow, false, checkers...) +} + +func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + + writeOpts := tcpip.WriteOptions{} + if setDest { + h := flow.header4Tuple(outgoing) + writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) + writeOpts = tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, + } + } payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { c.t.Fatalf("Write failed: %v", err) } - if n != uintptr(len(payload)) { + if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } - // Check that we received the packet. - b := c.getPacket(ipv6.ProtocolNumber, false) - udp := header.UDP(header.IPv6(b).Payload()) - checker.IPv6(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. + // Received the packet and check the payload. + b := c.getPacketAndVerify(flow, checkers...) + var udp header.UDP + if flow.isV4() { + udp = header.UDP(header.IPv4(b).Payload()) + } else { + udp = header.UDP(header.IPv6(b).Payload()) + } if !bytes.Equal(payload, udp.Payload()) { c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) } @@ -619,8 +861,10 @@ func testV6Write(c *testContext) uint16 { } func testDualWrite(c *testContext) uint16 { - v4Port := testV4Write(c) - v6Port := testV6Write(c) + c.t.Helper() + + v4Port := testWrite(c, unicastV4in6) + v6Port := testWrite(c, unicastV6) if v4Port != v6Port { c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) } @@ -632,7 +876,7 @@ func TestDualWriteUnbound(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) testDualWrite(c) } @@ -641,7 +885,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -658,69 +902,51 @@ func TestDualWriteConnectedToV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV6Write(c) + testWrite(c, unicastV6) // Write to V4 mapped address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != tcpip.ErrNetworkUnreachable { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable) - } + testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) } func TestDualWriteConnectedToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV4Write(c) + testWrite(c, unicastV4in6) // Write to v6 address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != tcpip.ErrInvalidEndpointState { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) - } + testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) } func TestV4WriteOnV6Only(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(true) + c.createEndpointForFlow(unicastV6Only) // Write to V4 mapped address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != tcpip.ErrNoRoute { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute) - } + testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute) } func TestV6WriteOnBoundToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Bind to v4 mapped address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { @@ -728,84 +954,154 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { } // Write to v6 address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != tcpip.ErrInvalidEndpointState { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) - } + testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) } func TestV6WriteOnConnected(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) } - // Write without destination. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - - // Check that we received the packet. - b := c.getPacket(ipv6.ProtocolNumber, false) - udp := header.UDP(header.IPv6(b).Payload()) - checker.IPv6(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } + testWriteWithoutDestination(c, unicastV6) } func TestV4WriteOnConnected(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) } - // Write without destination. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) + testWriteWithoutDestination(c, unicastV4) +} + +// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket +// that is bound to a V4 multicast address. +func TestWriteOnBoundToV4Multicast(t *testing.T) { + for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + testWrite(c, flow) + }) } - if n != uintptr(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) +} + +// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a +// socket that is bound to a V4-mapped multicast address. +func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { + for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4Mapped mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) } +} - // Check that we received the packet. - b := c.getPacket(ipv4.ProtocolNumber, false) - udp := header.UDP(header.IPv4(b).Payload()) - checker.IPv4(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) +// TestWriteOnBoundToV6Multicast checks that we can send packets out of a +// socket that is bound to a V6 multicast address. +func TestWriteOnBoundToV6Multicast(t *testing.T) { + for _, flow := range []testFlow{unicastV6, multicastV6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + c.createEndpointForFlow(flow) + + // Bind to V6 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToV6Multicast checks that we can send packets out of a +// V6-only socket that is bound to a V6 multicast address. +func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { + for _, flow := range []testFlow{unicastV6Only, multicastV6Only} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V6 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToBroadcast checks that we can send packets out of a +// socket that is bound to the broadcast address. +func TestWriteOnBoundToBroadcast(t *testing.T) { + for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4 broadcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a +// socket that is bound to the V4-mapped broadcast address. +func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) { + for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4Mapped mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) } } @@ -814,18 +1110,14 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { defer c.cleanup() // Create IPv4 UDP endpoint - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } + c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV4Read(c) + testRead(c, unicastV4) var want uint64 = 1 if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { @@ -837,7 +1129,7 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) testDualWrite(c) @@ -847,244 +1139,102 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { } } -func setSockOptVariants(t *testing.T, optFunc func(*testing.T, string, tcpip.NetworkProtocolNumber, string)) { - for _, name := range []string{"v4", "v6", "dual"} { - t.Run(name, func(t *testing.T) { - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch name { - case "v4": - networkProtocolNumber = ipv4.ProtocolNumber - case "v6", "dual": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatal("unknown test variant") - } +func TestTTL(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - var variants []string - switch name { - case "v4": - variants = []string{"v4"} - case "v6": - variants = []string{"v6"} - case "dual": - variants = []string{"v6", "mapped"} - } + c.createEndpointForFlow(flow) - for _, variant := range variants { - t.Run(variant, func(t *testing.T) { - optFunc(t, name, networkProtocolNumber, variant) - }) + const multicastTTL = 42 + if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) } - }) - } -} -func TestTTL(t *testing.T) { - payload := tcpip.SlicePayload(buffer.View(newPayload())) - - setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) { - for _, typ := range []string{"unicast", "multicast"} { - t.Run(typ, func(t *testing.T) { - var addr tcpip.Address - var port uint16 - switch typ { - case "unicast": - port = testPort - switch variant { - case "v4": - addr = testAddr - case "mapped": - addr = testV4MappedAddr - case "v6": - addr = testV6Addr - default: - t.Fatal("unknown test variant") - } - case "multicast": - port = multicastPort - switch variant { - case "v4": - addr = multicastAddr - case "mapped": - addr = multicastV4MappedAddr - case "v6": - addr = multicastV6Addr - default: - t.Fatal("unknown test variant") - } - default: - t.Fatal("unknown test variant") + var wantTTL uint8 + if flow.isMulticast() { + wantTTL = multicastTTL + } else { + var p stack.NetworkProtocol + if flow.isV4() { + p = ipv4.NewProtocol() + } else { + p = ipv6.NewProtocol() } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) + ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - switch name { - case "v4": - case "v6": - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - case "dual": - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - default: - t.Fatal("unknown test variant") - } - - const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + t.Fatal(err) } + wantTTL = ep.DefaultTTL() + ep.Close() + } - n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != uintptr(len(payload)) { - c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload)) - } + testWrite(c, flow, checker.TTL(wantTTL)) + }) + } +} - checkerFn := checker.IPv4 - switch variant { - case "v4", "mapped": - case "v6": - checkerFn = checker.IPv6 - default: - t.Fatal("unknown test variant") - } - var wantTTL uint8 - var multicast bool - switch typ { - case "unicast": - multicast = false - switch variant { - case "v4", "mapped": - ep, err := ipv4.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - case "v6": - ep, err := ipv6.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - default: - t.Fatal("unknown test variant") - } - case "multicast": - wantTTL = multicastTTL - multicast = true - default: - t.Fatal("unknown test variant") - } +func TestMulticastInterfaceOption(t *testing.T) { + for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + for _, bindTyp := range []string{"bound", "unbound"} { + t.Run(bindTyp, func(t *testing.T) { + for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { + t.Run(optTyp, func(t *testing.T) { + h := flow.header4Tuple(outgoing) + mcastAddr := h.dstAddr.Addr + localIfAddr := h.srcAddr.Addr + + var ifoptSet tcpip.MulticastInterfaceOption + switch optTyp { + case "use local-addr": + ifoptSet.InterfaceAddr = localIfAddr + case "use NICID": + ifoptSet.NIC = 1 + case "use local-addr and NIC": + ifoptSet.InterfaceAddr = localIfAddr + ifoptSet.NIC = 1 + default: + t.Fatal("unknown test variant") + } - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch variant { - case "v4", "mapped": - networkProtocolNumber = ipv4.ProtocolNumber - case "v6": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatal("unknown test variant") - } + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(flow.sockProto()) + + if bindTyp == "bound" { + // Bind the socket by connecting to the multicast address. + // This may have an influence on how the multicast interface + // is set. + addr := tcpip.FullAddress{ + Addr: flow.mapAddrIfApplicable(mcastAddr), + Port: stackPort, + } + if err := c.ep.Connect(addr); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + } - b := c.getPacket(networkProtocolNumber, multicast) - checkerFn(c.t, b, - checker.TTL(wantTTL), - checker.UDP( - checker.DstPort(port), - ), - ) - }) - } - }) -} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } -func TestMulticastInterfaceOption(t *testing.T) { - setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) { - for _, bindTyp := range []string{"bound", "unbound"} { - t.Run(bindTyp, func(t *testing.T) { - for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { - t.Run(optTyp, func(t *testing.T) { - var mcastAddr, localIfAddr tcpip.Address - switch variant { - case "v4": - mcastAddr = multicastAddr - localIfAddr = stackAddr - case "mapped": - mcastAddr = multicastV4MappedAddr - localIfAddr = stackAddr - case "v6": - mcastAddr = multicastV6Addr - localIfAddr = stackV6Addr - default: - t.Fatal("unknown test variant") - } - - var ifoptSet tcpip.MulticastInterfaceOption - switch optTyp { - case "use local-addr": - ifoptSet.InterfaceAddr = localIfAddr - case "use NICID": - ifoptSet.NIC = 1 - case "use local-addr and NIC": - ifoptSet.InterfaceAddr = localIfAddr - ifoptSet.NIC = 1 - default: - t.Fatal("unknown test variant") - } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if bindTyp == "bound" { - // Bind the socket by connecting to the multicast address. - // This may have an influence on how the multicast interface - // is set. - addr := tcpip.FullAddress{ - Addr: mcastAddr, - Port: multicastPort, + // Verify multicast interface addr and NIC were set correctly. + // Note that NIC must be 1 since this is our outgoing interface. + ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} + var ifoptGot tcpip.MulticastInterfaceOption + if err := c.ep.GetSockOpt(&ifoptGot); err != nil { + c.t.Fatalf("GetSockOpt failed: %v", err) } - if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %v", err) + if ifoptGot != ifoptWant { + c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) } - } - - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - // Verify multicast interface addr and NIC were set correctly. - // Note that NIC must be 1 since this is our outgoing interface. - ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} - var ifoptGot tcpip.MulticastInterfaceOption - if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %v", err) - } - if ifoptGot != ifoptWant { - c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) - } - }) - } - }) - } - }) + }) + } + }) + } + }) + } } |