summary refs log tree commit diff homepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go44
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go65
-rw-r--r--pkg/tcpip/header/ipv4.go9
-rw-r--r--pkg/tcpip/header/ipv6.go23
-rw-r--r--pkg/tcpip/iptables/BUILD5
-rw-r--r--pkg/tcpip/iptables/iptables.go4
-rw-r--r--pkg/tcpip/iptables/types.go19
-rw-r--r--pkg/tcpip/link/fdbased/BUILD4
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go179
-rw-r--r--pkg/tcpip/link/fdbased/mmap_amd64.go194
-rw-r--r--pkg/tcpip/link/fdbased/mmap_stub.go23
-rw-r--r--pkg/tcpip/link/fdbased/mmap_unsafe.go (renamed from pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go)2
-rw-r--r--pkg/tcpip/link/rawfile/BUILD4
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_amd64.s7
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_arm64.s42
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go31
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_unsafe.go8
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go (renamed from pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go)10
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go6
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go5
-rw-r--r--pkg/tcpip/network/arp/arp.go5
-rw-r--r--pkg/tcpip/network/arp/arp_test.go63
-rw-r--r--pkg/tcpip/network/ip_test.go6
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go1
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go28
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go6
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go33
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD1
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go5
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go9
-rw-r--r--pkg/tcpip/stack/BUILD24
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go253
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go79
-rw-r--r--pkg/tcpip/stack/nic.go306
-rw-r--r--pkg/tcpip/stack/route.go10
-rw-r--r--pkg/tcpip/stack/stack.go78
-rw-r--r--pkg/tcpip/stack/stack_test.go734
-rw-r--r--pkg/tcpip/stack/transport_test.go61
-rw-r--r--pkg/tcpip/tcpip.go101
-rw-r--r--pkg/tcpip/tcpip_test.go32
-rw-r--r--pkg/tcpip/transport/icmp/BUILD1
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go22
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go11
-rw-r--r--pkg/tcpip/transport/raw/BUILD1
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go26
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go16
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/accept.go89
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go86
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go161
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go10
-rw-r--r--pkg/tcpip/transport/tcp/snd.go41
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go8
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go60
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go14
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go1104
60 files changed, 2698 insertions, 1483 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 047f8329a..df37c7d5a 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -12,6 +12,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/iptables",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index c40924852..0d2637ee4 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -24,6 +24,7 @@ go_test(
embed = [":gonet"],
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 308f620e5..cd6ce930a 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -404,7 +404,7 @@ func (c *Conn) Write(b []byte) (int, error) {
}
}
- var n uintptr
+ var n int64
var resCh <-chan struct{}
n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
nbytes += int(n)
@@ -556,32 +556,50 @@ type PacketConn struct {
wq *waiter.Queue
}
-// NewPacketConn creates a new PacketConn.
-func NewPacketConn(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
- // Create UDP endpoint and bind it.
+// DialUDP creates a new PacketConn.
+//
+// If laddr is nil, a local address is automatically chosen.
+//
+// If raddr is nil, the PacketConn is left unconnected.
+func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
if err != nil {
return nil, errors.New(err.String())
}
- if err := ep.Bind(addr); err != nil {
- ep.Close()
- return nil, &net.OpError{
- Op: "bind",
- Net: "udp",
- Addr: fullToUDPAddr(addr),
- Err: errors.New(err.String()),
+ if laddr != nil {
+ if err := ep.Bind(*laddr); err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "bind",
+ Net: "udp",
+ Addr: fullToUDPAddr(*laddr),
+ Err: errors.New(err.String()),
+ }
}
}
- c := &PacketConn{
+ c := PacketConn{
stack: s,
ep: ep,
wq: &wq,
}
c.deadlineTimer.init()
- return c, nil
+
+ if raddr != nil {
+ if err := c.ep.Connect(*raddr); err != nil {
+ c.ep.Close()
+ return nil, &net.OpError{
+ Op: "connect",
+ Net: "udp",
+ Addr: fullToUDPAddr(*raddr),
+ Err: errors.New(err.String()),
+ }
+ }
+ }
+
+ return &c, nil
}
func (c *PacketConn) newOpError(op string, err error) *net.OpError {
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 39efe44c7..672f026b2 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -26,6 +26,7 @@ import (
"golang.org/x/net/nettest"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -69,17 +70,13 @@ func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
s.SetRouteTable([]tcpip.Route{
// IPv4
{
- Destination: tcpip.Address(strings.Repeat("\x00", 4)),
- Mask: tcpip.AddressMask(strings.Repeat("\x00", 4)),
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: NICID,
},
// IPv6
{
- Destination: tcpip.Address(strings.Repeat("\x00", 16)),
- Mask: tcpip.AddressMask(strings.Repeat("\x00", 16)),
- Gateway: "",
+ Destination: header.IPv6EmptySubnet,
NIC: NICID,
},
})
@@ -371,9 +368,9 @@ func TestUDPForwarder(t *testing.T) {
})
s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)
- c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber)
+ c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewPacketConn(port 5):", err)
+ t.Fatal("DialUDP(bind port 5):", err)
}
sent := "abc123"
@@ -452,13 +449,13 @@ func TestPacketConnTransfer(t *testing.T) {
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
- c1, err := NewPacketConn(s, addr1, ipv4.ProtocolNumber)
+ c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewPacketConn(port 4):", err)
+ t.Fatal("DialUDP(bind port 4):", err)
}
- c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber)
+ c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewPacketConn(port 5):", err)
+ t.Fatal("DialUDP(bind port 5):", err)
}
c1.SetDeadline(time.Now().Add(time.Second))
@@ -491,6 +488,50 @@ func TestPacketConnTransfer(t *testing.T) {
}
}
+func TestConnectedPacketConnTransfer(t *testing.T) {
+ s, e := newLoopbackStack()
+ if e != nil {
+ t.Fatalf("newLoopbackStack() = %v", e)
+ }
+
+ ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr := tcpip.FullAddress{NICID, ip, 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+
+ c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 4):", err)
+ }
+ c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 5):", err)
+ }
+
+ c1.SetDeadline(time.Now().Add(time.Second))
+ c2.SetDeadline(time.Now().Add(time.Second))
+
+ sent := "abc123"
+ if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) {
+ t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil)
+ }
+ recv := make([]byte, len(sent))
+ n, err := c1.Read(recv)
+ if err != nil || n != len(recv) {
+ t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil)
+ }
+
+ if recv := string(recv); recv != sent {
+ t.Errorf("got recv = %q, want = %q", recv, sent)
+ }
+
+ if err := c1.Close(); err != nil {
+ t.Error("c1.Close():", err)
+ }
+ if err := c2.Close(); err != nil {
+ t.Error("c2.Close():", err)
+ }
+}
+
func makePipe() (c1, c2 net.Conn, stop func(), err error) {
s, e := newLoopbackStack()
if e != nil {
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 94a3af289..17fc9c68e 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -111,6 +111,15 @@ const (
IPv4FlagDontFragment
)
+// IPv4EmptySubnet is the empty IPv4 subnet.
+var IPv4EmptySubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
// IPVersion returns the version of IP used in the given packet. It returns -1
// if the packet is not large enough to contain the version field.
func IPVersion(b []byte) int {
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 95fe8bfc3..bc4e56535 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -27,7 +27,7 @@ const (
nextHdr = 6
hopLimit = 7
v6SrcAddr = 8
- v6DstAddr = 24
+ v6DstAddr = v6SrcAddr + IPv6AddressSize
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
@@ -82,6 +82,15 @@ const (
IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
)
+// IPv6EmptySubnet is the empty IPv6 subnet.
+var IPv6EmptySubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
// PayloadLength returns the value of the "payload length" field of the ipv6
// header.
func (b IPv6) PayloadLength() uint16 {
@@ -110,13 +119,13 @@ func (b IPv6) Payload() []byte {
// SourceAddress returns the "source address" field of the ipv6 header.
func (b IPv6) SourceAddress() tcpip.Address {
- return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize])
+ return tcpip.Address(b[v6SrcAddr:][:IPv6AddressSize])
}
// DestinationAddress returns the "destination address" field of the ipv6
// header.
func (b IPv6) DestinationAddress() tcpip.Address {
- return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize])
+ return tcpip.Address(b[v6DstAddr:][:IPv6AddressSize])
}
// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
@@ -144,13 +153,13 @@ func (b IPv6) SetPayloadLength(payloadLength uint16) {
// SetSourceAddress sets the "source address" field of the ipv6 header.
func (b IPv6) SetSourceAddress(addr tcpip.Address) {
- copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr)
+ copy(b[v6SrcAddr:][:IPv6AddressSize], addr)
}
// SetDestinationAddress sets the "destination address" field of the ipv6
// header.
func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
- copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr)
+ copy(b[v6DstAddr:][:IPv6AddressSize], addr)
}
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
@@ -169,8 +178,8 @@ func (b IPv6) Encode(i *IPv6Fields) {
b.SetPayloadLength(i.PayloadLength)
b[nextHdr] = i.NextHeader
b[hopLimit] = i.HopLimit
- copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr)
- copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr)
+ b.SetSourceAddress(i.SrcAddr)
+ b.SetDestinationAddress(i.DstAddr)
}
// IsValid performs basic validation on the packet.
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
index fc9abbb55..3fc14bacd 100644
--- a/pkg/tcpip/iptables/BUILD
+++ b/pkg/tcpip/iptables/BUILD
@@ -11,8 +11,5 @@ go_library(
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables",
visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- ],
+ deps = ["//pkg/tcpip/buffer"],
)
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
index f1e1d1fad..68c68d4aa 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/iptables/iptables.go
@@ -32,8 +32,8 @@ const (
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables() *IPTables {
- return &IPTables{
+func DefaultTables() IPTables {
+ return IPTables{
Tables: map[string]Table{
tablenameNat: Table{
BuiltinChains: map[Hook]Chain{
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
index 600bd9a10..42a79ef9f 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/iptables/types.go
@@ -15,7 +15,6 @@
package iptables
import (
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -128,15 +127,29 @@ type Table struct {
// UserChains, and its purpose is to make looking up tables by name
// fast.
Chains map[string]*Chain
+
+ // metadata holds information about the Table that is useful to users
+ // of IPTables, but not to the netstack IPTables code itself.
+ metadata interface{}
}
// ValidHooks returns a bitmap of the builtin hooks for the given table.
-func (table *Table) ValidHooks() (uint32, *tcpip.Error) {
+func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
for hook, _ := range table.BuiltinChains {
hooks |= 1 << hook
}
- return hooks, nil
+ return hooks
+}
+
+// Metadata returns the metadata object stored in table.
+func (table *Table) Metadata() interface{} {
+ return table.metadata
+}
+
+// SetMetadata sets the metadata object stored in table.
+func (table *Table) SetMetadata(metadata interface{}) {
+ table.metadata = metadata
}
// A Chain defines a list of rules for packet processing. When a packet
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index d786d8fdf..74fbbb896 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -8,8 +8,8 @@ go_library(
"endpoint.go",
"endpoint_unsafe.go",
"mmap.go",
- "mmap_amd64.go",
- "mmap_amd64_unsafe.go",
+ "mmap_stub.go",
+ "mmap_unsafe.go",
"packet_dispatchers.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased",
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index fe19c2bc2..8bfeb97e4 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -12,14 +12,183 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build !linux !amd64
+// +build linux,amd64 linux,arm64
package fdbased
-import "gvisor.dev/gvisor/pkg/tcpip"
+import (
+ "encoding/binary"
+ "syscall"
-// Stubbed out version for non-linux/non-amd64 platforms.
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+)
-func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, *tcpip.Error) {
- return nil, nil
+const (
+ tPacketAlignment = uintptr(16)
+ tpStatusKernel = 0
+ tpStatusUser = 1
+ tpStatusCopy = 2
+ tpStatusLosing = 4
+)
+
+// We overallocate the frame size to accommodate space for the
+// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
+//
+// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
+//
+// NOTE:
+// Frames need to be aligned at 16 byte boundaries.
+// BlockSize needs to be page aligned.
+//
+// For details see PACKET_MMAP setting constraints in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+const (
+ tpFrameSize = 65536 + 128
+ tpBlockSize = tpFrameSize * 32
+ tpBlockNR = 1
+ tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
+)
+
+// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
+// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
+func tPacketAlign(v uintptr) uintptr {
+ return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
+}
+
+// tPacketReq is the tpacket_req structure as described in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+type tPacketReq struct {
+ tpBlockSize uint32
+ tpBlockNR uint32
+ tpFrameSize uint32
+ tpFrameNR uint32
+}
+
+// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
+type tPacketHdr []byte
+
+const (
+ tpStatusOffset = 0
+ tpLenOffset = 8
+ tpSnapLenOffset = 12
+ tpMacOffset = 16
+ tpNetOffset = 18
+ tpSecOffset = 20
+ tpUSecOffset = 24
+)
+
+func (t tPacketHdr) tpLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpLenOffset:])
+}
+
+func (t tPacketHdr) tpSnapLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
+}
+
+func (t tPacketHdr) tpMac() uint16 {
+ return binary.LittleEndian.Uint16(t[tpMacOffset:])
+}
+
+func (t tPacketHdr) tpNet() uint16 {
+ return binary.LittleEndian.Uint16(t[tpNetOffset:])
+}
+
+func (t tPacketHdr) tpSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSecOffset:])
+}
+
+func (t tPacketHdr) tpUSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpUSecOffset:])
+}
+
+func (t tPacketHdr) Payload() []byte {
+ return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
+}
+
+// packetMMapDispatcher uses a PACKET_RX_RING to read/dispatch inbound packets.
+// See mmap_unsafe.go for implementation details.
+type packetMMapDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // ringBuffer is only used when PacketMMap dispatcher is used and points
+ // to the start of the mmapped PACKET_RX_RING buffer.
+ ringBuffer []byte
+
+ // ringOffset is the current offset into the ring buffer where the next
+ // inbound packet will be placed by the kernel.
+ ringOffset int
+}
+
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
+ hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ for hdr.tpStatus()&tpStatusUser == 0 {
+ event := rawfile.PollEvent{
+ FD: int32(d.fd),
+ Events: unix.POLLIN | unix.POLLERR,
+ }
+ if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
+ if errno == syscall.EINTR {
+ continue
+ }
+ return nil, rawfile.TranslateErrno(errno)
+ }
+ if hdr.tpStatus()&tpStatusCopy != 0 {
+ // This frame is truncated so skip it after flipping the
+ // buffer to the kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ continue
+ }
+ }
+
+ // Copy out the packet from the mmapped frame to a locally owned buffer.
+ pkt := make([]byte, hdr.tpSnapLen())
+ copy(pkt, hdr.Payload())
+ // Release packet to kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ return pkt, nil
+}
+
+// dispatch reads packets from an mmapped ring buffer and dispatches them to the
+// network stack.
+func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
+ pkt, err := d.readMMappedPacket()
+ if err != nil {
+ return false, err
+ }
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ )
+ if d.e.hdrSize > 0 {
+ eth := header.Ethernet(pkt)
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(pkt) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ pkt = pkt[d.e.hdrSize:]
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
+ return true, nil
}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64.go b/pkg/tcpip/link/fdbased/mmap_amd64.go
deleted file mode 100644
index 8bbb4f9ab..000000000
--- a/pkg/tcpip/link/fdbased/mmap_amd64.go
+++ /dev/null
@@ -1,194 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux,amd64
-
-package fdbased
-
-import (
- "encoding/binary"
- "syscall"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
-)
-
-const (
- tPacketAlignment = uintptr(16)
- tpStatusKernel = 0
- tpStatusUser = 1
- tpStatusCopy = 2
- tpStatusLosing = 4
-)
-
-// We overallocate the frame size to accommodate space for the
-// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
-//
-// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
-//
-// NOTE:
-// Frames need to be aligned at 16 byte boundaries.
-// BlockSize needs to be page aligned.
-//
-// For details see PACKET_MMAP setting constraints in
-// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
-const (
- tpFrameSize = 65536 + 128
- tpBlockSize = tpFrameSize * 32
- tpBlockNR = 1
- tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
-)
-
-// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
-// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
-func tPacketAlign(v uintptr) uintptr {
- return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
-}
-
-// tPacketReq is the tpacket_req structure as described in
-// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
-type tPacketReq struct {
- tpBlockSize uint32
- tpBlockNR uint32
- tpFrameSize uint32
- tpFrameNR uint32
-}
-
-// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
-type tPacketHdr []byte
-
-const (
- tpStatusOffset = 0
- tpLenOffset = 8
- tpSnapLenOffset = 12
- tpMacOffset = 16
- tpNetOffset = 18
- tpSecOffset = 20
- tpUSecOffset = 24
-)
-
-func (t tPacketHdr) tpLen() uint32 {
- return binary.LittleEndian.Uint32(t[tpLenOffset:])
-}
-
-func (t tPacketHdr) tpSnapLen() uint32 {
- return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
-}
-
-func (t tPacketHdr) tpMac() uint16 {
- return binary.LittleEndian.Uint16(t[tpMacOffset:])
-}
-
-func (t tPacketHdr) tpNet() uint16 {
- return binary.LittleEndian.Uint16(t[tpNetOffset:])
-}
-
-func (t tPacketHdr) tpSec() uint32 {
- return binary.LittleEndian.Uint32(t[tpSecOffset:])
-}
-
-func (t tPacketHdr) tpUSec() uint32 {
- return binary.LittleEndian.Uint32(t[tpUSecOffset:])
-}
-
-func (t tPacketHdr) Payload() []byte {
- return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
-}
-
-// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
-// See: mmap_amd64_unsafe.go for implementation details.
-type packetMMapDispatcher struct {
- // fd is the file descriptor used to send and receive packets.
- fd int
-
- // e is the endpoint this dispatcher is attached to.
- e *endpoint
-
- // ringBuffer is only used when PacketMMap dispatcher is used and points
- // to the start of the mmapped PACKET_RX_RING buffer.
- ringBuffer []byte
-
- // ringOffset is the current offset into the ring buffer where the next
- // inbound packet will be placed by the kernel.
- ringOffset int
-}
-
-func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
- hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
- for hdr.tpStatus()&tpStatusUser == 0 {
- event := rawfile.PollEvent{
- FD: int32(d.fd),
- Events: unix.POLLIN | unix.POLLERR,
- }
- if _, errno := rawfile.BlockingPoll(&event, 1, -1); errno != 0 {
- if errno == syscall.EINTR {
- continue
- }
- return nil, rawfile.TranslateErrno(errno)
- }
- if hdr.tpStatus()&tpStatusCopy != 0 {
- // This frame is truncated so skip it after flipping the
- // buffer to the kernel.
- hdr.setTPStatus(tpStatusKernel)
- d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
- continue
- }
- }
-
- // Copy out the packet from the mmapped frame to a locally owned buffer.
- pkt := make([]byte, hdr.tpSnapLen())
- copy(pkt, hdr.Payload())
- // Release packet to kernel.
- hdr.setTPStatus(tpStatusKernel)
- d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- return pkt, nil
-}
-
-// dispatch reads packets from an mmaped ring buffer and dispatches them to the
-// network stack.
-func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
- pkt, err := d.readMMappedPacket()
- if err != nil {
- return false, err
- }
- var (
- p tcpip.NetworkProtocolNumber
- remote, local tcpip.LinkAddress
- )
- if d.e.hdrSize > 0 {
- eth := header.Ethernet(pkt)
- p = eth.Type()
- remote = eth.SourceAddress()
- local = eth.DestinationAddress()
- } else {
- // We don't get any indication of what the packet is, so try to guess
- // if it's an IPv4 or IPv6 packet.
- switch header.IPVersion(pkt) {
- case header.IPv4Version:
- p = header.IPv4ProtocolNumber
- case header.IPv6Version:
- p = header.IPv6ProtocolNumber
- default:
- return true, nil
- }
- }
-
- pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
- return true, nil
-}
diff --git a/pkg/tcpip/link/fdbased/mmap_stub.go b/pkg/tcpip/link/fdbased/mmap_stub.go
new file mode 100644
index 000000000..67be52d67
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/mmap_stub.go
@@ -0,0 +1,23 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !linux !amd64,!arm64
+
+package fdbased
+
+// Stubbed out version for non-linux/non-amd64/non-arm64 platforms.
+
+func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ return nil, nil
+}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go
index 47cb1d1cc..3894185ae 100644
--- a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go
+++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64
+// +build linux,amd64 linux,arm64
package fdbased
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 6e3a7a9d7..088eb8a21 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -6,8 +6,10 @@ go_library(
name = "rawfile",
srcs = [
"blockingpoll_amd64.s",
- "blockingpoll_amd64_unsafe.go",
+ "blockingpoll_arm64.s",
+ "blockingpoll_noyield_unsafe.go",
"blockingpoll_unsafe.go",
+ "blockingpoll_yield_unsafe.go",
"errors.go",
"rawfile_unsafe.go",
],
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
index b54131573..298bad55d 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
+++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
@@ -14,17 +14,18 @@
#include "textflag.h"
-// BlockingPoll makes the poll() syscall while calling the version of
+// BlockingPoll makes the ppoll() syscall while calling the version of
// entersyscall that relinquishes the P so that other Gs can run. This is meant
// to be called in cases when the syscall is expected to block.
//
-// func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (n int, err syscall.Errno)
+// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno)
TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
CALL ·callEntersyscallblock(SB)
MOVQ fds+0(FP), DI
MOVQ nfds+8(FP), SI
MOVQ timeout+16(FP), DX
- MOVQ $0x7, AX // SYS_POLL
+ MOVQ $0x0, R10 // sigmask parameter which isn't used here
+ MOVQ $0x10f, AX // SYS_PPOLL
SYSCALL
CMPQ AX, $0xfffffffffffff001
JLS ok
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
new file mode 100644
index 000000000..b62888b93
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// BlockingPoll makes the ppoll() syscall while calling the version of
+// entersyscall that relinquishes the P so that other Gs can run. This is meant
+// to be called in cases when the syscall is expected to block.
+//
+// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno)
+TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
+ BL ·callEntersyscallblock(SB)
+ MOVD fds+0(FP), R0
+ MOVD nfds+8(FP), R1
+ MOVD timeout+16(FP), R2
+ MOVD $0x0, R3 // sigmask parameter which isn't used here
+ MOVD $0x49, R8 // SYS_PPOLL
+ SVC
+ CMP $0xfffffffffffff001, R0
+ BLS ok
+ MOVD $-1, R1
+ MOVD R1, n+24(FP)
+ NEG R0, R0
+ MOVD R0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
+ok:
+ MOVD R0, n+24(FP)
+ MOVD $0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
new file mode 100644
index 000000000..621ab8d29
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,!amd64,!arm64
+
+package rawfile
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// BlockingPoll is just a stub function that forwards to the ppoll() system call
+// on non-amd64 and non-arm64 platforms.
+func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) {
+ n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)),
+ uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0)
+
+ return int(n), e
+}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
index 4eab77c74..84dc0e918 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
@@ -21,9 +21,11 @@ import (
"unsafe"
)
-// BlockingPoll is just a stub function that forwards to the poll() system call
+// BlockingPoll is just a stub function that forwards to the ppoll() system call
// on non-amd64 platforms.
-func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (int, syscall.Errno) {
- n, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(fds)), uintptr(nfds), uintptr(timeout))
+func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) {
+ n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)),
+ uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0)
+
return int(n), e
}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index c87268610..dda3b10a6 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64
+// +build linux,amd64 linux,arm64
// +build go1.12
// +build !go1.14
@@ -25,8 +25,14 @@ import (
_ "unsafe" // for go:linkname
)
+// BlockingPoll on amd64/arm64 makes the ppoll() syscall while calling the
+// version of entersyscall that relinquishes the P so that other Gs can
+// run. This is meant to be called in cases when the syscall is expected to
+// block. On non amd64/arm64 platforms it just forwards to the ppoll() system
+// call.
+//
//go:noescape
-func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (int, syscall.Errno)
+func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno)
// Use go:linkname to call into the runtime. As of Go 1.12 this has to
// be done from Go code so that we make an ABIInternal call to an
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index e3fbb15c2..7e286a3a6 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -123,7 +123,7 @@ func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
Events: 1, // POLLIN
}
- _, e = BlockingPoll(&event, 1, -1)
+ _, e = BlockingPoll(&event, 1, nil)
if e != 0 && e != syscall.EINTR {
return 0, TranslateErrno(e)
}
@@ -145,7 +145,7 @@ func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) {
Events: 1, // POLLIN
}
- _, e = BlockingPoll(&event, 1, -1)
+ _, e = BlockingPoll(&event, 1, nil)
if e != 0 && e != syscall.EINTR {
return 0, TranslateErrno(e)
}
@@ -175,7 +175,7 @@ func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
Events: 1, // POLLIN
}
- if _, e := BlockingPoll(&event, 1, -1); e != 0 && e != syscall.EINTR {
+ if _, e := BlockingPoll(&event, 1, nil); e != 0 && e != syscall.EINTR {
return 0, TranslateErrno(e)
}
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index fc584c6a4..36c8c46fc 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -360,10 +360,9 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize {
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
+ details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
+ size -= header.UDPMinimumSize
}
- size -= header.UDPMinimumSize
-
- details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
case header.TCPProtocolNumber:
transName = "tcp"
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 60070874d..fd6395fc1 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -109,13 +109,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
pkt.SetOp(header.ARPReply)
copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:])
copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
+ copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender())
copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
- fallthrough // also fill the cache from requests
case header.ARPReply:
- addr := tcpip.Address(h.ProtocolAddressSender())
- linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 66c55821b..4c4b54469 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -15,6 +15,7 @@
package arp_test
import (
+ "strconv"
"testing"
"time"
@@ -65,9 +66,7 @@ func newTestContext(t *testing.T) *testContext {
}
s.SetRouteTable([]tcpip.Route{{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
}})
@@ -101,40 +100,30 @@ func TestDirectRequest(t *testing.T) {
c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView())
}
- inject(stackAddr1)
- {
- pkt := <-c.linkEP.C
- if pkt.Proto != arp.ProtocolNumber {
- t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto)
- }
- rep := header.ARP(pkt.Header)
- if !rep.IsValid() {
- t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
- }
- if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 {
- t.Errorf("stackAddr1: expected sender to be set")
- }
- if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
- t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got)
- }
- }
-
- inject(stackAddr2)
- {
- pkt := <-c.linkEP.C
- if pkt.Proto != arp.ProtocolNumber {
- t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto)
- }
- rep := header.ARP(pkt.Header)
- if !rep.IsValid() {
- t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
- }
- if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 {
- t.Errorf("stackAddr2: expected sender to be set")
- }
- if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
- t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got)
- }
+ for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ inject(address)
+ pkt := <-c.linkEP.C
+ if pkt.Proto != arp.ProtocolNumber {
+ t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto)
+ }
+ rep := header.ARP(pkt.Header)
+ if !rep.IsValid() {
+ t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want {
+ t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want)
+ }
+ })
}
inject(stackAddrBad)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 55e9eec99..6bbfcd97f 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -173,8 +173,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
- Destination: ipv4SubnetAddr,
- Mask: ipv4SubnetMask,
+ Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
NIC: 1,
}})
@@ -187,8 +186,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
- Destination: ipv6SubnetAddr,
- Mask: ipv6SubnetMask,
+ Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
NIC: 1,
}})
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index fbef6947d..497164cbb 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -94,6 +94,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
copy(pkt, h)
pkt.SetType(header.ICMPv4EchoReply)
+ pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
sent := stats.ICMP.V4PacketsSent
if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil {
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 3207a3d46..1b5a55bea 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -52,9 +52,7 @@ func TestExcludeBroadcast(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
}})
@@ -247,14 +245,22 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
_, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
linkEPId := stack.RegisterLinkEndpoint(linkEP)
s.CreateNIC(1, linkEPId)
- s.AddAddress(1, ipv4.ProtocolNumber, "\x10\x00\x00\x01")
- s.SetRouteTable([]tcpip.Route{{
- Destination: "\x10\x00\x00\x02",
- Mask: "\xff\xff\xff\xff",
- Gateway: "",
- NIC: 1,
- }})
- r, err := s.FindRoute(0, "\x10\x00\x00\x01", "\x10\x00\x00\x02", ipv4.ProtocolNumber, false /* multicastLoop */)
+ const (
+ src = "\x10\x00\x00\x01"
+ dst = "\x10\x00\x00\x02"
+ )
+ s.AddAddress(1, ipv4.ProtocolNumber, src)
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("s.FindRoute got %v, want %v", err, nil)
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 5e6a59e91..1689af16f 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -100,13 +100,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
-
if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:16])
+ targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
// We don't have a useful answer; the best we can do is ignore the request.
return
@@ -146,7 +144,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:16])
+ targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
if targetAddr != r.RemoteAddress {
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 726362c87..d0dc72506 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -91,13 +91,18 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
}
}
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: lladdr1,
- Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)),
- NIC: 1,
- }},
- )
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
netProto := s.NetworkProtocolInstance(ProtocolNumber)
if netProto == nil {
@@ -237,17 +242,23 @@ func newTestContext(t *testing.T) *testContext {
t.Fatalf("AddAddress sn lladdr1: %v", err)
}
+ subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
c.s0.SetRouteTable(
[]tcpip.Route{{
- Destination: lladdr1,
- Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)),
+ Destination: subnet0,
NIC: 1,
}},
)
+ subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
+ if err != nil {
+ t.Fatal(err)
+ }
c.s1.SetRouteTable(
[]tcpip.Route{{
- Destination: lladdr0,
- Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)),
+ Destination: subnet1,
NIC: 1,
}},
)
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
index 996939581..a57752a7c 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -8,6 +8,7 @@ go_binary(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/link/fdbased",
"//pkg/tcpip/link/rawfile",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 3ac381631..e2021cd15 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -52,6 +52,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
"gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -152,9 +153,7 @@ func main() {
// Add default route.
s.SetRouteTable([]tcpip.Route{
{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
},
})
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index da425394a..1716be285 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -149,12 +149,15 @@ func main() {
log.Fatal(err)
}
+ subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
+ if err != nil {
+ log.Fatal(err)
+ }
+
// Add default route.
s.SetRouteTable([]tcpip.Route{
{
- Destination: tcpip.Address(strings.Repeat("\x00", len(addr))),
- Mask: tcpip.AddressMask(strings.Repeat("\x00", len(addr))),
- Gateway: "",
+ Destination: subnet,
NIC: 1,
},
})
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 28d11c797..b692c60ce 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,11 +1,25 @@
package(licenses = ["notice"])
+load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+go_template_instance(
+ name = "linkaddrentry_list",
+ out = "linkaddrentry_list.go",
+ package = "stack",
+ prefix = "linkAddrEntry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*linkAddrEntry",
+ "Linker": "*linkAddrEntry",
+ },
+)
+
go_library(
name = "stack",
srcs = [
"linkaddrcache.go",
+ "linkaddrentry_list.go",
"nic.go",
"registration.go",
"route.go",
@@ -24,6 +38,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/waiter",
@@ -42,6 +57,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/waiter",
@@ -58,3 +74,11 @@ go_test(
"//pkg/tcpip",
],
)
+
+filegroup(
+ name = "autogen",
+ srcs = [
+ "linkaddrentry_list.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 77bb0ccb9..267df60d1 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -42,10 +42,11 @@ type linkAddrCache struct {
// resolved before failing.
resolutionAttempts int
- mu sync.Mutex
- cache map[tcpip.FullAddress]*linkAddrEntry
- next int // array index of next available entry
- entries [linkAddrCacheSize]linkAddrEntry
+ cache struct {
+ sync.Mutex
+ table map[tcpip.FullAddress]*linkAddrEntry
+ lru linkAddrEntryList
+ }
}
// entryState controls the state of a single entry in the cache.
@@ -60,9 +61,6 @@ const (
// failed means that address resolution timed out and the address
// could not be resolved.
failed
- // expired means that the cache entry has expired and the address must be
- // resolved again.
- expired
)
// String implements Stringer.
@@ -74,8 +72,6 @@ func (s entryState) String() string {
return "ready"
case failed:
return "failed"
- case expired:
- return "expired"
default:
return fmt.Sprintf("unknown(%d)", s)
}
@@ -84,64 +80,46 @@ func (s entryState) String() string {
// A linkAddrEntry is an entry in the linkAddrCache.
// This struct is thread-compatible.
type linkAddrEntry struct {
+ linkAddrEntryEntry
+
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
// wakers is a set of waiters for address resolution result. Anytime
- // state transitions out of 'incomplete' these waiters are notified.
+ // state transitions out of incomplete these waiters are notified.
wakers map[*sleep.Waker]struct{}
+ // done is used to allow callers to wait on address resolution. It is non-nil
+ // iff s is incomplete and resolution is in progress.
done chan struct{}
}
-func (e *linkAddrEntry) state() entryState {
- if e.s != expired && time.Now().After(e.expiration) {
- // Force the transition to ensure waiters are notified.
- e.changeState(expired)
- }
- return e.s
-}
-
-func (e *linkAddrEntry) changeState(ns entryState) {
- if e.s == ns {
- return
- }
-
- // Validate state transition.
- switch e.s {
- case incomplete:
- // All transitions are valid.
- case ready, failed:
- if ns != expired {
- panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
- }
- case expired:
- // Terminal state.
- panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
- default:
- panic(fmt.Sprintf("invalid state: %s", e.s))
- }
-
+// changeState sets the entry's state to ns, notifying any waiters.
+//
+// The entry's expiration is bumped up to the greater of itself and the passed
+// expiration; the zero value indicates immediate expiration, and is set
+// unconditionally - this is an implementation detail that allows for entries
+// to be reused.
+func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
// Notify whoever is waiting on address resolution when transitioning
- // out of 'incomplete'.
- if e.s == incomplete {
+ // out of incomplete.
+ if e.s == incomplete && ns != incomplete {
for w := range e.wakers {
w.Assert()
}
e.wakers = nil
- if e.done != nil {
- close(e.done)
+ if ch := e.done; ch != nil {
+ close(ch)
}
+ e.done = nil
}
- e.s = ns
-}
-func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) {
- if w != nil {
- e.wakers[w] = struct{}{}
+ if expiration.IsZero() || expiration.After(e.expiration) {
+ e.expiration = expiration
}
+ e.s = ns
}
func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
@@ -150,53 +128,54 @@ func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- entry, ok := c.cache[k]
- if ok {
- s := entry.state()
- if s != expired && entry.linkAddr == v {
- // Disregard repeated calls.
- return
- }
- // Check if entry is waiting for address resolution.
- if s == incomplete {
- entry.linkAddr = v
- } else {
- // Otherwise create a new entry to replace it.
- entry = c.makeAndAddEntry(k, v)
- }
- } else {
- entry = c.makeAndAddEntry(k, v)
- }
+ // Calculate expiration time before acquiring the lock, since expiration is
+ // relative to the time when information was learned, rather than when it
+ // happened to be inserted into the cache.
+ expiration := time.Now().Add(c.ageLimit)
- entry.changeState(ready)
+ c.cache.Lock()
+ entry := c.getOrCreateEntryLocked(k)
+ entry.linkAddr = v
+
+ entry.changeState(ready, expiration)
+ c.cache.Unlock()
}
-// makeAndAddEntry is a helper function to create and add a new
-// entry to the cache map and evict older entry as needed.
-func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry {
- // Take over the next entry.
- entry := &c.entries[c.next]
- if c.cache[entry.addr] == entry {
- delete(c.cache, entry.addr)
+// getOrCreateEntryLocked retrieves a cache entry associated with k. The
+// returned entry is always refreshed in the cache (it is reachable via the
+// map, and its place is bumped in LRU).
+//
+// If a matching entry exists in the cache, it is returned. If no matching
+// entry exists and the cache is full, an existing entry is evicted via LRU,
+// reset to state incomplete, and returned. If no matching entry exists and the
+// cache is not full, a new entry with state incomplete is allocated and
+// returned.
+func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry {
+ if entry, ok := c.cache.table[k]; ok {
+ c.cache.lru.Remove(entry)
+ c.cache.lru.PushFront(entry)
+ return entry
}
+ var entry *linkAddrEntry
+ if len(c.cache.table) == linkAddrCacheSize {
+ entry = c.cache.lru.Back()
- // Mark the soon-to-be-replaced entry as expired, just in case there is
- // someone waiting for address resolution on it.
- entry.changeState(expired)
+ delete(c.cache.table, entry.addr)
+ c.cache.lru.Remove(entry)
- *entry = linkAddrEntry{
- addr: k,
- linkAddr: v,
- expiration: time.Now().Add(c.ageLimit),
- wakers: make(map[*sleep.Waker]struct{}),
- done: make(chan struct{}),
+ // Wake waiters and mark the soon-to-be-reused entry as expired. Any state
+ // other than incomplete works here: the entry is reset immediately below.
+ entry.changeState(failed, time.Time{})
+ } else {
+ entry = new(linkAddrEntry)
}
- c.cache[k] = entry
- c.next = (c.next + 1) % len(c.entries)
+ *entry = linkAddrEntry{
+ addr: k,
+ s: incomplete,
+ }
+ c.cache.table[k] = entry
+ c.cache.lru.PushFront(entry)
return entry
}
@@ -208,43 +187,55 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
}
}
- c.mu.Lock()
- defer c.mu.Unlock()
- if entry, ok := c.cache[k]; ok {
- switch s := entry.state(); s {
- case expired:
- case ready:
- return entry.linkAddr, nil, nil
- case failed:
- return "", nil, tcpip.ErrNoLinkAddress
- case incomplete:
- // Address resolution is still in progress.
- entry.maybeAddWaker(waker)
- return "", entry.done, tcpip.ErrWouldBlock
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry := c.getOrCreateEntryLocked(k)
+ switch s := entry.s; s {
+ case ready, failed:
+ if !time.Now().After(entry.expiration) {
+ // Not expired.
+ switch s {
+ case ready:
+ return entry.linkAddr, nil, nil
+ case failed:
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
}
- }
- if linkRes == nil {
- return "", nil, tcpip.ErrNoLinkAddress
- }
+ entry.changeState(incomplete, time.Time{})
+ fallthrough
+ case incomplete:
+ if waker != nil {
+ if entry.wakers == nil {
+ entry.wakers = make(map[*sleep.Waker]struct{})
+ }
+ entry.wakers[waker] = struct{}{}
+ }
- // Add 'incomplete' entry in the cache to mark that resolution is in progress.
- e := c.makeAndAddEntry(k, "")
- e.maybeAddWaker(waker)
+ if entry.done == nil {
+ // Address resolution needs to be initiated.
+ if linkRes == nil {
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ }
- go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ entry.done = make(chan struct{})
+ go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ }
- return "", e.done, tcpip.ErrWouldBlock
+ return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
}
// removeWaker removes a waker previously added through get().
func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
- c.mu.Lock()
- defer c.mu.Unlock()
+ c.cache.Lock()
+ defer c.cache.Unlock()
- if entry, ok := c.cache[k]; ok {
+ if entry, ok := c.cache.table[k]; ok {
entry.removeWaker(waker)
}
}
@@ -256,8 +247,8 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
select {
- case <-time.After(c.resolutionTimeout):
- if stop := c.checkLinkRequest(k, i); stop {
+ case now := <-time.After(c.resolutionTimeout):
+ if stop := c.checkLinkRequest(now, k, i); stop {
return
}
case <-done:
@@ -269,38 +260,36 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
// checkLinkRequest checks whether previous attempt to resolve address has succeeded
// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
// can stop, false if another request should be sent.
-func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- entry, ok := c.cache[k]
+func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry, ok := c.cache.table[k]
if !ok {
// Entry was evicted from the cache.
return true
}
-
- switch s := entry.state(); s {
- case ready, failed, expired:
+ switch s := entry.s; s {
+ case ready, failed:
// Entry was made ready by resolver or failed. Either way we're done.
- return true
case incomplete:
- if attempt+1 >= c.resolutionAttempts {
- // Max number of retries reached, mark entry as failed.
- entry.changeState(failed)
- return true
+ if attempt+1 < c.resolutionAttempts {
+ // No response yet, need to send another ARP request.
+ return false
}
- // No response yet, need to send another ARP request.
- return false
+ // Max number of retries reached, mark entry as failed.
+ entry.changeState(failed, now.Add(c.ageLimit))
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
+ return true
}
func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
- return &linkAddrCache{
+ c := &linkAddrCache{
ageLimit: ageLimit,
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
- cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
}
+ c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize)
+ return c
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 924f4d240..9946b8fe8 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -17,6 +17,7 @@ package stack
import (
"fmt"
"sync"
+ "sync/atomic"
"testing"
"time"
@@ -29,25 +30,34 @@ type testaddr struct {
linkAddr tcpip.LinkAddress
}
-var testaddrs []testaddr
+var testAddrs = func() []testaddr {
+ var addrs []testaddr
+ for i := 0; i < 4*linkAddrCacheSize; i++ {
+ addr := fmt.Sprintf("Addr%06d", i)
+ addrs = append(addrs, testaddr{
+ addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
+ linkAddr: tcpip.LinkAddress("Link" + addr),
+ })
+ }
+ return addrs
+}()
type testLinkAddressResolver struct {
- cache *linkAddrCache
- delay time.Duration
+ cache *linkAddrCache
+ delay time.Duration
+ onLinkAddressRequest func()
}
func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
- go func() {
- if r.delay > 0 {
- time.Sleep(r.delay)
- }
- r.fakeRequest(addr)
- }()
+ time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
+ if f := r.onLinkAddressRequest; f != nil {
+ f()
+ }
return nil
}
func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
- for _, ta := range testaddrs {
+ for _, ta := range testAddrs {
if ta.addr.Addr == addr {
r.cache.add(ta.addr, ta.linkAddr)
break
@@ -80,20 +90,10 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe
}
}
-func init() {
- for i := 0; i < 4*linkAddrCacheSize; i++ {
- addr := fmt.Sprintf("Addr%06d", i)
- testaddrs = append(testaddrs, testaddr{
- addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
- linkAddr: tcpip.LinkAddress("Link" + addr),
- })
- }
-}
-
func TestCacheOverflow(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- for i := len(testaddrs) - 1; i >= 0; i-- {
- e := testaddrs[i]
+ for i := len(testAddrs) - 1; i >= 0; i-- {
+ e := testAddrs[i]
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
@@ -105,7 +105,7 @@ func TestCacheOverflow(t *testing.T) {
}
// Expect to find at least half of the most recent entries.
for i := 0; i < linkAddrCacheSize/2; i++ {
- e := testaddrs[i]
+ e := testAddrs[i]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
@@ -115,8 +115,8 @@ func TestCacheOverflow(t *testing.T) {
}
}
// The earliest entries should no longer be in the cache.
- for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
- e := testaddrs[i]
+ for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
+ e := testAddrs[i]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
}
@@ -130,7 +130,7 @@ func TestCacheConcurrent(t *testing.T) {
for r := 0; r < 16; r++ {
wg.Add(1)
go func() {
- for _, e := range testaddrs {
+ for _, e := range testAddrs {
c.add(e.addr, e.linkAddr)
c.get(e.addr, nil, "", nil, nil) // make work for gotsan
}
@@ -142,7 +142,7 @@ func TestCacheConcurrent(t *testing.T) {
// All goroutines add in the same order and add more values than
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
- e := testaddrs[len(testaddrs)-1]
+ e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -151,7 +151,7 @@ func TestCacheConcurrent(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
- e = testaddrs[0]
+ e = testAddrs[0]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
@@ -159,7 +159,7 @@ func TestCacheConcurrent(t *testing.T) {
func TestCacheAgeLimit(t *testing.T) {
c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
- e := testaddrs[0]
+ e := testAddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
@@ -169,7 +169,7 @@ func TestCacheAgeLimit(t *testing.T) {
func TestCacheReplace(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- e := testaddrs[0]
+ e := testAddrs[0]
l2 := e.linkAddr + "2"
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
@@ -193,7 +193,7 @@ func TestCacheReplace(t *testing.T) {
func TestCacheResolution(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
linkRes := &testLinkAddressResolver{cache: c}
- for i, ta := range testaddrs {
+ for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
@@ -205,7 +205,7 @@ func TestCacheResolution(t *testing.T) {
// Check that after resolved, address stays in the cache and never returns WouldBlock.
for i := 0; i < 10; i++ {
- e := testaddrs[len(testaddrs)-1]
+ e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -220,8 +220,13 @@ func TestCacheResolutionFailed(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
linkRes := &testLinkAddressResolver{cache: c}
+ var requestCount uint32
+ linkRes.onLinkAddressRequest = func() {
+ atomic.AddUint32(&requestCount, 1)
+ }
+
// First, sanity check that resolution is working...
- e := testaddrs[0]
+ e := testAddrs[0]
got, err := getBlocking(c, e.addr, linkRes)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -230,10 +235,16 @@ func TestCacheResolutionFailed(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
+ before := atomic.LoadUint32(&requestCount)
+
e.addr.Addr += "2"
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
+
+ if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
+ t.Errorf("got link address request count = %d, want = %d", got, want)
+ }
}
func TestCacheResolutionTimeout(t *testing.T) {
@@ -242,7 +253,7 @@ func TestCacheResolutionTimeout(t *testing.T) {
c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
- e := testaddrs[0]
+ e := testAddrs[0]
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 3e6ff4afb..89b4c5960 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -139,7 +139,7 @@ func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Add
if list, ok := n.primary[protocol]; ok {
for e := list.Front(); e != nil; e = e.Next() {
ref := e.(*referencedNetworkEndpoint)
- if ref.holdsInsertRef && ref.tryIncRef() {
+ if ref.kind == permanent && ref.tryIncRef() {
r = ref
break
}
@@ -178,7 +178,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
case header.IPv4Broadcast, header.IPv4Any:
continue
}
- if r.tryIncRef() {
+ if r.isValidForOutgoing() && r.tryIncRef() {
return r
}
}
@@ -186,82 +186,155 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
return nil
}
+func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
+ return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous)
+}
+
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing)
+}
+
+// getRefOrCreateTemp returns the referenced network endpoint for the given
+// protocol and address. If none exists a temporary one may be created if
+// we are in promiscuous mode or spoofing.
+func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
- ref := n.endpoints[id]
- if ref != nil && !ref.tryIncRef() {
- ref = nil
+
+ if ref, ok := n.endpoints[id]; ok {
+ // An endpoint with this id exists, check if it can be used and return it.
+ switch ref.kind {
+ case permanentExpired:
+ if !spoofingOrPromiscuous {
+ n.mu.RUnlock()
+ return nil
+ }
+ fallthrough
+ case temporary, permanent:
+ if ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
+ }
+ }
+ }
+
+ // A usable reference was not found, create a temporary one if requested by
+ // the caller or if the address is found in the NIC's subnets.
+ createTempEP := spoofingOrPromiscuous
+ if !createTempEP {
+ for _, sn := range n.subnets {
+ if sn.Contains(address) {
+ createTempEP = true
+ break
+ }
+ }
}
- spoofing := n.spoofing
+
n.mu.RUnlock()
- if ref != nil || !spoofing {
- return ref
+ if !createTempEP {
+ return nil
}
// Try again with the lock in exclusive mode. If we still can't get the
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- ref = n.endpoints[id]
- if ref == nil || !ref.tryIncRef() {
- if netProto, ok := n.stack.networkProtocols[protocol]; ok {
- addrWithPrefix := tcpip.AddressWithPrefix{address, netProto.DefaultPrefixLen()}
- ref, _ = n.addAddressLocked(protocol, addrWithPrefix, peb, true)
- if ref != nil {
- ref.holdsInsertRef = false
- }
+ if ref, ok := n.endpoints[id]; ok {
+ // No need to check the type as we are ok with expired endpoints at this
+ // point.
+ if ref.tryIncRef() {
+ n.mu.Unlock()
+ return ref
}
+ // tryIncRef failing means the endpoint is scheduled to be removed once the
+ // lock is released. Remove it here so we can create a new (temporary) one.
+ // The removal logic waiting for the lock handles this case.
+ n.removeEndpointLocked(ref)
}
- n.mu.Unlock()
- return ref
-}
-func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+ // Add a new temporary endpoint.
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
- return nil, tcpip.ErrUnknownProtocol
+ n.mu.Unlock()
+ return nil
}
+ ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: netProto.DefaultPrefixLen(),
+ },
+ }, peb, temporary)
- // Create the new network endpoint.
- ep, err := netProto.NewEndpoint(n.id, addrWithPrefix, n.stack, n, n.linkEP)
- if err != nil {
- return nil, err
- }
+ n.mu.Unlock()
+ return ref
+}
- id := *ep.ID()
+func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) {
+ id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
if ref, ok := n.endpoints[id]; ok {
- if !replace {
+ switch ref.kind {
+ case permanent:
+ // The NIC already has a permanent endpoint with that address.
return nil, tcpip.ErrDuplicateAddress
+ case permanentExpired, temporary:
+ // Promote the endpoint to become permanent.
+ if ref.tryIncRef() {
+ ref.kind = permanent
+ return ref, nil
+ }
+ // tryIncRef failing means the endpoint is scheduled to be removed once
+ // the lock is released. Remove it here so we can create a new
+ // (permanent) one. The removal logic waiting for the lock handles this
+ // case.
+ n.removeEndpointLocked(ref)
}
+ }
+ return n.addAddressLocked(protocolAddress, peb, permanent)
+}
- n.removeEndpointLocked(ref)
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) {
+ // Sanity check.
+ id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
+ if _, ok := n.endpoints[id]; ok {
+ // Endpoint already exists.
+ return nil, tcpip.ErrDuplicateAddress
}
+ netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ // Create the new network endpoint.
+ ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP)
+ if err != nil {
+ return nil, err
+ }
ref := &referencedNetworkEndpoint{
- refs: 1,
- ep: ep,
- nic: n,
- protocol: protocol,
- holdsInsertRef: true,
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocolAddress.Protocol,
+ kind: kind,
}
// Set up cache if link address resolution exists for this protocol.
if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
- if _, ok := n.stack.linkAddrResolvers[protocol]; ok {
+ if _, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok {
ref.linkCache = n.stack
}
}
n.endpoints[id] = ref
- l, ok := n.primary[protocol]
+ l, ok := n.primary[protocolAddress.Protocol]
if !ok {
l = &ilist.List{}
- n.primary[protocol] = l
+ n.primary[protocolAddress.Protocol] = l
}
switch peb {
@@ -276,10 +349,10 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addrWithPre
// AddAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) *tcpip.Error {
+func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
- _, err := n.addAddressLocked(protocol, addrWithPrefix, peb, false)
+ _, err := n.addPermanentAddressLocked(protocolAddress, peb)
n.mu.Unlock()
return err
@@ -291,6 +364,12 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress {
defer n.mu.RUnlock()
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ref := range n.endpoints {
+ // Don't include expired or temporary endpoints to avoid confusion and
+ // prevent the caller from using those.
+ switch ref.kind {
+ case permanentExpired, temporary:
+ continue
+ }
addrs = append(addrs, tcpip.ProtocolAddress{
Protocol: ref.protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -356,13 +435,16 @@ func (n *NIC) Subnets() []tcpip.Subnet {
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
id := *r.ep.ID()
- // Nothing to do if the reference has already been replaced with a
- // different one.
+ // Nothing to do if the reference has already been replaced with a different
+ // one. This happens in the case where 1) this endpoint's ref count hit zero
+ // and was waiting (on the lock) to be removed and 2) the same address was
+ // re-added in the meantime by removing this endpoint from the list and
+ // adding a new one.
if n.endpoints[id] != r {
return
}
- if r.holdsInsertRef {
+ if r.kind == permanent {
panic("Reference count dropped to zero before being removed")
}
@@ -381,14 +463,13 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
n.mu.Unlock()
}
-func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
+func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
r := n.endpoints[NetworkEndpointID{addr}]
- if r == nil || !r.holdsInsertRef {
+ if r == nil || r.kind != permanent {
return tcpip.ErrBadLocalAddress
}
- r.holdsInsertRef = false
-
+ r.kind = permanentExpired
r.decRefLocked()
return nil
@@ -398,7 +479,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- return n.removeAddressLocked(addr)
+ return n.removePermanentAddressLocked(addr)
}
// joinGroup adds a new endpoint for the given multicast address, if none
@@ -414,8 +495,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
if !ok {
return tcpip.ErrUnknownProtocol
}
- addrWithPrefix := tcpip.AddressWithPrefix{addr, netProto.DefaultPrefixLen()}
- if _, err := n.addAddressLocked(protocol, addrWithPrefix, NeverPrimaryEndpoint, false); err != nil {
+ if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: netProto.DefaultPrefixLen(),
+ },
+ }, NeverPrimaryEndpoint); err != nil {
return err
}
}
@@ -437,7 +523,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadLocalAddress
case 1:
// This is the last one, clean up.
- if err := n.removeAddressLocked(addr); err != nil {
+ if err := n.removePermanentAddressLocked(addr); err != nil {
return err
}
}
@@ -445,6 +531,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
return nil
}
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) {
+ r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
+ r.RemoteLinkAddress = remotelinkAddr
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+}
+
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
// the NIC receives a packet from the physical interface.
@@ -472,6 +565,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
src, dst := netProto.ParseAddresses(vv.First())
+ n.stack.AddLinkAddress(n.id, src, remote)
+
// If the packet is destined to the IPv4 Broadcast address, then make a
// route to each IPv4 network endpoint and let each endpoint handle the
// packet.
@@ -479,11 +574,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
// n.endpoints is mutex protected so acquire lock.
n.mu.RLock()
for _, ref := range n.endpoints {
- if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
- r.RemoteLinkAddress = remote
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
+ if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
}
}
n.mu.RUnlock()
@@ -491,10 +583,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
}
if ref := n.getRef(protocol, dst); ref != nil {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
- r.RemoteLinkAddress = remote
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
return
}
@@ -517,8 +606,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n := r.ref.nic
n.mu.RLock()
ref, ok := n.endpoints[NetworkEndpointID{dst}]
+ ok = ok && ref.isValidForOutgoing() && ref.tryIncRef()
n.mu.RUnlock()
- if ok && ref.tryIncRef() {
+ if ok {
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
ref.ep.HandlePacket(&r, vv)
@@ -543,52 +633,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n.stack.stats.IP.InvalidAddressesReceived.Increment()
}
-func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
- id := NetworkEndpointID{dst}
-
- n.mu.RLock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
- }
-
- promiscuous := n.promiscuous
- // Check if the packet is for a subnet this NIC cares about.
- if !promiscuous {
- for _, sn := range n.subnets {
- if sn.Contains(dst) {
- promiscuous = true
- break
- }
- }
- }
- n.mu.RUnlock()
- if promiscuous {
- // Try again with the lock in exclusive mode. If we still can't
- // get the endpoint, create a new "temporary" one. It will only
- // exist while there's a route through it.
- n.mu.Lock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.Unlock()
- return ref
- }
- netProto, ok := n.stack.networkProtocols[protocol]
- if !ok {
- n.mu.Unlock()
- return nil
- }
- addrWithPrefix := tcpip.AddressWithPrefix{dst, netProto.DefaultPrefixLen()}
- ref, err := n.addAddressLocked(protocol, addrWithPrefix, CanBePrimaryEndpoint, true)
- n.mu.Unlock()
- if err == nil {
- ref.holdsInsertRef = false
- return ref
- }
- }
-
- return nil
-}
-
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
@@ -676,9 +720,33 @@ func (n *NIC) ID() tcpip.NICID {
return n.id
}
+type networkEndpointKind int
+
+const (
+ // A permanent endpoint is created by adding a permanent address (vs. a
+ // temporary one) to the NIC. Its reference count is biased by 1 to avoid
+ // removal when no route holds a reference to it. It is removed by explicitly
+ // removing the permanent address from the NIC.
+ permanent networkEndpointKind = iota
+
+ // An expired permanent endpoint is a permanent endpoint that had its address
+ // removed from the NIC, and it is waiting to be removed once no more routes
+ // hold a reference to it. This is achieved by decreasing its reference count
+ // by 1. If its address is re-added before the endpoint is removed, its type
+ // changes back to permanent and its reference count increases by 1 again.
+ permanentExpired
+
+ // A temporary endpoint is created for spoofing outgoing packets, or when in
+ // promiscuous mode and accepting incoming packets that don't match any
+ // permanent endpoint. Its reference count is not biased by 1 and the
+ // endpoint is removed immediately when no route holds a reference to
+ // it. A temporary endpoint can be promoted to permanent if its address
+ // is added permanently.
+ temporary
+)
+
type referencedNetworkEndpoint struct {
ilist.Entry
- refs int32
ep NetworkEndpoint
nic *NIC
protocol tcpip.NetworkProtocolNumber
@@ -687,11 +755,25 @@ type referencedNetworkEndpoint struct {
// protocol. Set to nil otherwise.
linkCache LinkAddressCache
- // holdsInsertRef is protected by the NIC's mutex. It indicates whether
- // the reference count is biased by 1 due to the insertion of the
- // endpoint. It is reset to false when RemoveAddress is called on the
- // NIC.
- holdsInsertRef bool
+ // refs is counting references held for this endpoint. When refs hits zero it
+ // triggers the automatic removal of the endpoint from the NIC.
+ refs int32
+
+ kind networkEndpointKind
+}
+
+// isValidForOutgoing returns true if the endpoint can be used to send out a
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), or the NIC to be in spoofing mode.
+func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
+ return r.kind != permanentExpired || r.nic.spoofing
+}
+
+// isValidForIncoming returns true if the endpoint can accept an incoming
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), or the NIC to be in promiscuous mode.
+func (r *referencedNetworkEndpoint) isValidForIncoming() bool {
+ return r.kind != permanentExpired || r.nic.promiscuous
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 391ab4344..e52cdd674 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -148,11 +148,15 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before the this route can be written to.
func (r *Route) IsResolutionRequired() bool {
- return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+ return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
@@ -166,6 +170,10 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 57b8a9994..d69162ba1 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/waiter"
@@ -333,6 +334,15 @@ type TCPEndpointState struct {
Sender TCPSenderState
}
+// ResumableEndpoint is an endpoint that needs to be resumed after restore.
+type ResumableEndpoint interface {
+ // Resume resumes an endpoint after restore. This can be used to restart
+ // background workers such as protocol goroutines. This must be called after
+ // all indirect dependencies of the endpoint have been restored, which
+ // generally implies at the end of the restore process.
+ Resume(*Stack)
+}
+
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -372,6 +382,13 @@ type Stack struct {
// handleLocal allows non-loopback interfaces to loop packets.
handleLocal bool
+
+ // tables are the iptables packet filtering and manipulation rules.
+ tables iptables.IPTables
+
+ // resumableEndpoints is a list of endpoints that need to be resumed if the
+ // stack is being restored.
+ resumableEndpoints []ResumableEndpoint
}
// Options contains optional Stack configuration.
@@ -751,10 +768,10 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber,
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
-// AddAddressWithPrefix adds a new network-layer address/prefixLen to the
+// AddProtocolAddress adds a new network-layer protocol address to the
// specified NIC.
-func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix) *tcpip.Error {
- return s.AddAddressWithPrefixAndOptions(id, protocol, addrWithPrefix, CanBePrimaryEndpoint)
+func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error {
+ return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint)
}
// AddAddressWithOptions is the same as AddAddress, but allows you to specify
@@ -764,13 +781,18 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt
if !ok {
return tcpip.ErrUnknownProtocol
}
- addrWithPrefix := tcpip.AddressWithPrefix{addr, netProto.DefaultPrefixLen()}
- return s.AddAddressWithPrefixAndOptions(id, protocol, addrWithPrefix, peb)
-}
-
-// AddAddressWithPrefixAndOptions is the same as AddAddressWithPrefixLen,
-// but allows you to specify whether the new endpoint can be primary or not.
-func (s *Stack) AddAddressWithPrefixAndOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) *tcpip.Error {
+ return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: netProto.DefaultPrefixLen(),
+ },
+ }, peb)
+}
+
+// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows
+// you to specify whether the new endpoint can be primary or not.
+func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -779,7 +801,7 @@ func (s *Stack) AddAddressWithPrefixAndOptions(id tcpip.NICID, protocol tcpip.Ne
return tcpip.ErrUnknownNICID
}
- return nic.AddAddress(protocol, addrWithPrefix, peb)
+ return nic.AddAddress(protocolAddress, peb)
}
// AddSubnet adds a subnet range to the specified NIC.
@@ -873,7 +895,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
} else {
for _, route := range s.routeTable {
- if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Match(remoteAddr)) {
+ if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) {
continue
}
if nic, ok := s.nics[route.NIC]; ok {
@@ -1082,6 +1104,28 @@ func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip
}
}
+// RegisterRestoredEndpoint records e as an endpoint that has been restored on
+// this stack.
+func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) {
+ s.mu.Lock()
+ s.resumableEndpoints = append(s.resumableEndpoints, e)
+ s.mu.Unlock()
+}
+
+// Resume restarts the stack after a restore. This must be called after the
+// entire system has been restored.
+func (s *Stack) Resume() {
+ // ResumableEndpoint.Resume() may call other methods on s, so we can't hold
+ // s.mu while resuming the endpoints.
+ s.mu.Lock()
+ eps := s.resumableEndpoints
+ s.resumableEndpoints = nil
+ s.mu.Unlock()
+ for _, e := range eps {
+ e.Resume(s)
+ }
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
@@ -1161,3 +1205,13 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC
}
return tcpip.ErrUnknownNICID
}
+
+// IPTables returns the stack's iptables.
+func (s *Stack) IPTables() iptables.IPTables {
+ return s.tables
+}
+
+// SetIPTables sets the stack's iptables.
+func (s *Stack) SetIPTables(ipt iptables.IPTables) {
+ s.tables = ipt
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 9d082bba4..4debd1eec 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -181,6 +181,10 @@ func (f *fakeNetworkProtocol) DefaultPrefixLen() int {
return fakeDefaultPrefixLen
}
+func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
+ return f.packetCount[int(intfAddr)%len(f.packetCount)]
+}
+
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
@@ -188,7 +192,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
nicid: nicid,
- id: stack.NetworkEndpointID{addrWithPrefix.Address},
+ id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
proto: f,
dispatcher: dispatcher,
@@ -289,16 +293,75 @@ func TestNetworkReceive(t *testing.T) {
}
}
-func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) {
+func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error {
r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatal("FindRoute failed:", err)
+ return err
}
defer r.Release()
+ return send(r, payload)
+}
+func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil {
- t.Error("WritePacket failed:", err)
+ return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123)
+}
+
+func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ linkEP.Drain()
+ if err := sendTo(s, addr, payload); err != nil {
+ t.Error("sendTo failed:", err)
+ }
+ if got, want := linkEP.Drain(), 1; got != want {
+ t.Errorf("sendTo packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ linkEP.Drain()
+ if err := send(r, payload); err != nil {
+ t.Error("send failed:", err)
+ }
+ if got, want := linkEP.Drain(), 1; got != want {
+ t.Errorf("send packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := send(r, payload); gotErr != wantErr {
+ t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
+ t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte) + 1
+ testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want)
+}
+
+func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we do NOT expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte)
+ testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want)
+}
+
+func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) {
+ t.Helper()
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ if got := fakeNet.PacketCount(localAddrByte); got != want {
+ t.Errorf("receive packet count: got = %d, want %d", got, want)
}
}
@@ -312,17 +375,20 @@ func TestNetworkSend(t *testing.T) {
t.Fatal("NewNIC failed:", err)
}
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress failed:", err)
}
// Make sure that the link-layer endpoint received the outbound packet.
- sendTo(t, s, "\x03", nil)
- if c := linkEP.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x03", linkEP, nil)
}
func TestNetworkSendMultiRoute(t *testing.T) {
@@ -360,24 +426,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
// Set a route table that sends all packets with odd destination
// addresses through the first NIC, and all even destination address
// through the second one.
- s.SetRouteTable([]tcpip.Route{
- {"\x01", "\x01", "\x00", 1},
- {"\x00", "\x01", "\x00", 2},
- })
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet0, Gateway: "\x00", NIC: 2},
+ })
+ }
// Send a packet to an odd destination.
- sendTo(t, s, "\x05", nil)
-
- if c := linkEP1.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x05", linkEP1, nil)
// Send a packet to an even destination.
- sendTo(t, s, "\x06", nil)
-
- if c := linkEP2.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x06", linkEP2, nil)
}
func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
@@ -439,10 +507,20 @@ func TestRoutes(t *testing.T) {
// Set a route table that sends all packets with odd destination
// addresses through the first NIC, and all even destination addresses
// through the second one.
- s.SetRouteTable([]tcpip.Route{
- {"\x01", "\x01", "\x00", 1},
- {"\x00", "\x01", "\x00", 2},
- })
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet0, Gateway: "\x00", NIC: 2},
+ })
+ }
// Test routes to odd address.
testRoute(t, s, 0, "", "\x05", "\x01")
@@ -472,6 +550,10 @@ func TestRoutes(t *testing.T) {
}
func TestAddressRemoval(t *testing.T) {
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
+
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
@@ -479,99 +561,285 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ // Send and receive packets, and verify they are received.
+ buf[0] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
- // Remove the address, then check that packet doesn't get delivered
- // anymore.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
}
-func TestDelayedRemovalDueToRoute(t *testing.T) {
+func TestAddressRemovalWithRouteHeld(t *testing.T) {
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
+
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatal("CreateNIC failed:", err)
}
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
-
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
buf := buffer.NewView(30)
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- // Get a route, check that packet is still deliverable.
- r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 2 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2)
- }
+ // Send and receive packets, and verify they are received.
+ buf[0] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSend(t, r, linkEP, nil)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
- // Remove the address, then check that packet is still deliverable
- // because the route is keeping the address alive.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
+}
- // Release the route, then check that packet is not deliverable anymore.
- r.Release()
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
+// verifyAddress checks the fakeNet addresses that NICInfo reports for nicid.
+//
+// When addr is non-empty, addr must be present on the NIC and every other
+// fakeNet address found is reported as an error. When addr is empty, any
+// fakeNet entry with a non-zero AddressWithPrefix is reported as an error.
+// The test is aborted if NICInfo has no entry for nicid.
+func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) {
+	t.Helper()
+	nicInfo, known := s.NICInfo()[nicid]
+	if !known {
+		t.Fatalf("NICInfo() failed to find nicid=%d", nicid)
+	}
+
+	if len(addr) != 0 {
+		// An address is expected: it must be the only fakeNet address on
+		// the NIC.
+		seen := false
+		for _, pa := range nicInfo.ProtocolAddresses {
+			if pa.Protocol != fakeNetNumber {
+				continue
+			}
+			if pa.AddressWithPrefix.Address != addr {
+				t.Errorf("verify address: got = %s, want = %s", pa.AddressWithPrefix.Address, addr)
+				continue
+			}
+			seen = true
+		}
+		if !seen {
+			t.Errorf("verify address: couldn't find %s on the NIC", addr)
+		}
+		return
+	}
+
+	// No address is expected: every fakeNet entry must carry the zero
+	// AddressWithPrefix.
+	var none tcpip.AddressWithPrefix
+	for _, pa := range nicInfo.ProtocolAddresses {
+		if pa.Protocol == fakeNetNumber && pa.AddressWithPrefix != none {
+			t.Errorf("verify no-address: got = %s, want = %s", pa.AddressWithPrefix, none)
+		}
+	}
+}
+
+// TestEndpointExpiration walks a NIC through adding and removing its address,
+// with and without a route held, for every combination of promiscuous mode and
+// spoofing, checking after each step whether receive and send succeed or fail.
+func TestEndpointExpiration(t *testing.T) {
+ const (
+ localAddrByte byte = 0x01
+ remoteAddr tcpip.Address = "\x03"
+ noAddr tcpip.Address = ""
+ nicid tcpip.NICID = 1
+ )
+ localAddr := tcpip.Address([]byte{localAddrByte})
+
+ for _, promiscuous := range []bool{true, false} {
+ for _, spoofing := range []bool{true, false} {
+ t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+
+ id, linkEP := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, id); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+ buf := buffer.NewView(30)
+ buf[0] = localAddrByte
+
+ if promiscuous {
+ if err := s.SetPromiscuousMode(nicid, true); err != nil {
+ t.Fatal("SetPromiscuousMode failed:", err)
+ }
+ }
+
+ if spoofing {
+ if err := s.SetSpoofing(nicid, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ }
+
+ // 1. No Address yet, send should only work for spoofing, receive for
+ // promiscuous mode.
+ //-----------------------
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, linkEP, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ }
+
+ // 2. Add Address, everything should work.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+
+ // 3. Remove the address, send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, linkEP, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ }
+
+ // 4. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+
+ // 5. Take a reference to the endpoint by getting a route. Verify that
+ // we can still send/receive, including sending using the route.
+ //-----------------------
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+ testSend(t, r, linkEP, nil)
+
+ // 6. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ }
+ if spoofing {
+ testSend(t, r, linkEP, nil)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+ } else {
+ testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ }
+
+ // 7. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+ testSend(t, r, linkEP, nil)
+
+ // 8. Release the route, sendTo/recv should still work.
+ //-----------------------
+ r.Release()
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+
+ // 9. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, linkEP, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ }
+ })
+ }
}
}
@@ -583,9 +851,13 @@ func TestPromiscuousMode(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
@@ -593,22 +865,15 @@ func TestPromiscuousMode(t *testing.T) {
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
// Set promiscuous mode, then check that packet is delivered.
if err := s.SetPromiscuousMode(1, true); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
// Check that we can't get a route as there is no local address.
_, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
@@ -621,54 +886,120 @@ func TestPromiscuousMode(t *testing.T) {
if err := s.SetPromiscuousMode(1, false); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+// TestSpoofingWithAddress checks FindRoute and sending on a NIC that has one
+// address assigned. A source address that was never added to the NIC must be
+// rejected while spoofing is disabled and accepted once it is enabled; the
+// address that was added must keep working with spoofing enabled.
+func TestSpoofingWithAddress(t *testing.T) {
+	localAddr := tcpip.Address("\x01")
+	nonExistentLocalAddr := tcpip.Address("\x02")
+	dstAddr := tcpip.Address("\x03")
+
+	s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+
+	id, linkEP := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatal("CreateNIC failed:", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
+		t.Fatal("AddAddress failed:", err)
+	}
+
+	{
+		subnet, err := tcpip.NewSubnet("\x00", "\x00")
+		if err != nil {
+			t.Fatal(err)
+		}
+		s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+	}
+
+	// With address spoofing disabled, FindRoute does not permit an address
+	// that was not added to the NIC to be used as the source.
+	r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+	if err == nil {
+		t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
+	}
+
+	// With address spoofing enabled, FindRoute permits any address to be used
+	// as the source.
+	if err := s.SetSpoofing(1, true); err != nil {
+		t.Fatal("SetSpoofing failed:", err)
+	}
+	r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+	if err != nil {
+		t.Fatal("FindRoute failed:", err)
+	}
+	if r.LocalAddress != nonExistentLocalAddr {
+		t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
+	}
+	if r.RemoteAddress != dstAddr {
+		t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
 	}
+	// Sending a packet works.
+	testSendTo(t, s, dstAddr, linkEP, nil)
+	testSend(t, r, linkEP, nil)
+
+	// FindRoute should also work with a local address that exists on the NIC.
+	r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+	if err != nil {
+		t.Fatal("FindRoute failed:", err)
+	}
+	if r.LocalAddress != localAddr {
+		// Report the address this check actually compares against.
+		t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, localAddr)
+	}
+	if r.RemoteAddress != dstAddr {
+		t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+	}
+	// Sending a packet using the route works.
+	testSend(t, r, linkEP, nil)
 }
-func TestAddressSpoofing(t *testing.T) {
- srcAddr := tcpip.Address("\x01")
+func TestSpoofingNoAddress(t *testing.T) {
+ nonExistentLocalAddr := tcpip.Address("\x01")
dstAddr := tcpip.Address("\x02")
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
+ id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
-
// With address spoofing disabled, FindRoute does not permit an address
// that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err == nil {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
+ // Sending a packet fails.
+ testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute)
// With address spoofing enabled, FindRoute permits any address to be used
// as the source.
if err := s.SetSpoofing(1, true); err != nil {
t.Fatal("SetSpoofing failed:", err)
}
- r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- if r.LocalAddress != srcAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr)
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
}
+ // Sending a packet should work, but the check is disabled for now.
+ // FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, dstAddr, linkEP, nil)
}
func TestBroadcastNeedsNoRoute(t *testing.T) {
@@ -806,16 +1137,20 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- buf[0] = 1
- fakeNet.packetCount[1] = 0
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -824,9 +1159,52 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) {
t.Fatal("AddSubnet failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+}
+
+// TestCheckLocalAddressForSubnet adds a subnet to the NIC, then checks that
+// CheckLocalAddress returns that NIC for every address in the subnet and 0 for
+// the first address past the end of it.
+func TestCheckLocalAddressForSubnet(t *testing.T) {
+ const nicID tcpip.NICID = 1
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+
+ id, _ := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, id); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}})
+ }
+
+ subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0"))
+
+ if err != nil {
+ t.Fatal("NewSubnet failed:", err)
+ }
+ if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddSubnet failed:", err)
+ }
+
+ // Loop over all subnet addresses and check them. Addresses are one byte
+ // long here, so the subnet holds 2^(8-prefix) of them.
+ numOfAddresses := 1 << uint(8-subnet.Prefix())
+ if numOfAddresses < 1 || numOfAddresses > 255 {
+ t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
+ }
+ addr := []byte(subnet.ID())
+ for i := 0; i < numOfAddresses; i++ {
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID)
+ }
+ addr[0]++
+ }
+
+ // Trying the next address should fail since it is outside the subnet range.
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0)
}
}
@@ -839,16 +1217,20 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- buf[0] = 1
- fakeNet.packetCount[1] = 0
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -856,10 +1238,7 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
t.Fatal("AddSubnet failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
}
func TestNetworkOptions(t *testing.T) {
@@ -969,12 +1348,18 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// prefixLen.
address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
if behavior == stack.CanBePrimaryEndpoint {
- addressWithPrefix := tcpip.AddressWithPrefix{address, addrLen * 8}
- if err := s.AddAddressWithPrefixAndOptions(1, fakeNetNumber, addressWithPrefix, behavior); err != nil {
- t.Fatal("AddAddressWithPrefixAndOptions failed:", err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: addrLen * 8,
+ },
+ }
+ if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil {
+ t.Fatal("AddProtocolAddressWithOptions failed:", err)
}
// Remember the address/prefix.
- primaryAddrAdded[addressWithPrefix] = struct{}{}
+ primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
} else {
if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil {
t.Fatal("AddAddressWithOptions failed:", err)
@@ -1024,20 +1409,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
{"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116},
} {
t.Run(tc.name, func(t *testing.T) {
- addressWithPrefix := tcpip.AddressWithPrefix{tc.address, tc.prefixLen}
-
- if err := s.AddAddressWithPrefix(1, fakeNetNumber, addressWithPrefix); err != nil {
- t.Fatal("AddAddressWithPrefix failed:", err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tc.address,
+ PrefixLen: tc.prefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddress); err != nil {
+ t.Fatal("AddProtocolAddress failed:", err)
}
// Check that we get the right initial address and prefix length.
if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
t.Fatal("GetMainNICAddress failed:", err)
- } else if gotAddressWithPrefix != addressWithPrefix {
- t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, addressWithPrefix)
+ } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix)
}
- if err := s.RemoveAddress(1, addressWithPrefix.Address); err != nil {
+ if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
@@ -1102,7 +1492,7 @@ func TestAddAddress(t *testing.T) {
verifyAddresses(t, expectedAddresses, gotAddresses)
}
-func TestAddAddressWithPrefix(t *testing.T) {
+func TestAddProtocolAddress(t *testing.T) {
const nicid = 1
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
@@ -1116,14 +1506,17 @@ func TestAddAddressWithPrefix(t *testing.T) {
expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange))
for _, addrLen := range addrLenRange {
for _, prefixLen := range prefixLenRange {
- address := addrGen.next(addrLen)
- if err := s.AddAddressWithPrefix(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}); err != nil {
- t.Errorf("AddAddressWithPrefix(address=%s, prefixLen=%d) failed: %s", address, prefixLen, err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addrGen.next(addrLen),
+ PrefixLen: prefixLen,
+ },
}
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen},
- })
+ if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil {
+ t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
+ }
+ expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
@@ -1160,7 +1553,7 @@ func TestAddAddressWithOptions(t *testing.T) {
verifyAddresses(t, expectedAddresses, gotAddresses)
}
-func TestAddAddressWithPrefixAndOptions(t *testing.T) {
+func TestAddProtocolAddressWithOptions(t *testing.T) {
const nicid = 1
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
@@ -1176,14 +1569,17 @@ func TestAddAddressWithPrefixAndOptions(t *testing.T) {
for _, addrLen := range addrLenRange {
for _, prefixLen := range prefixLenRange {
for _, behavior := range behaviorRange {
- address := addrGen.next(addrLen)
- if err := s.AddAddressWithPrefixAndOptions(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}, behavior); err != nil {
- t.Fatalf("AddAddressWithPrefixAndOptions(address=%s, prefixLen=%d, behavior=%d) failed: %s", address, prefixLen, behavior, err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addrGen.next(addrLen),
+ PrefixLen: prefixLen,
+ },
}
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen},
- })
+ if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
+ }
+ expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
}
@@ -1196,15 +1592,19 @@ func TestNICStats(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id1, linkEP1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id1); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ t.Fatal("CreateNIC failed: ", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress failed:", err)
}
// Route all packets for address \x01 to NIC 1.
- s.SetRouteTable([]tcpip.Route{
- {"\x01", "\xff", "\x00", 1},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x01", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
// Send a packet to address 1.
buf := buffer.NewView(30)
@@ -1219,7 +1619,9 @@ func TestNICStats(t *testing.T) {
payload := buffer.NewView(10)
// Write a packet out via the address for NIC 1
- sendTo(t, s, "\x01", payload)
+ if err := sendTo(s, "\x01", payload); err != nil {
+ t.Fatal("sendTo failed: ", err)
+ }
want := uint64(linkEP1.Drain())
if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want {
t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want)
@@ -1253,9 +1655,13 @@ func TestNICForwarding(t *testing.T) {
}
// Route all packets to address 3 to NIC 2.
- s.SetRouteTable([]tcpip.Route{
- {"\x03", "\xff", "\x00", 2},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x03", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}})
+ }
// Send a packet to address 3.
buf := buffer.NewView(30)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index b418db046..5335897f5 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -64,7 +65,7 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr
return buffer.View{}, tcpip.ControlMessages{}, nil
}
-func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
return 0, nil, tcpip.ErrNoRoute
}
@@ -78,10 +79,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions)
return 0, nil, err
}
- return uintptr(len(v)), nil, nil
+ return int64(len(v)), nil, nil
}
-func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -104,6 +105,11 @@ func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// Disconnect implements tcpip.Endpoint.Disconnect. The fake endpoint always
+// returns tcpip.ErrNotSupported.
+func (*fakeTransportEndpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
@@ -200,6 +206,13 @@ func (f *fakeTransportEndpoint) State() uint32 {
func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {
}
+// IPTables returns an empty iptables.IPTables and a nil error.
+// NOTE(review): presumably present to satisfy the endpoint interface used by
+// the stack; confirm against the interface definition.
+func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
+ return iptables.IPTables{}, nil
+}
+
+// Resume is a no-op for the fake endpoint; the stack argument is ignored.
+func (f *fakeTransportEndpoint) Resume(*stack.Stack) {
+}
+
type fakeTransportGoodOption bool
type fakeTransportBadOption bool
@@ -271,7 +284,13 @@ func TestTransportReceive(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
@@ -327,7 +346,13 @@ func TestTransportControlReceive(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
@@ -393,7 +418,13 @@ func TestTransportSend(t *testing.T) {
t.Fatalf("AddAddress failed: %v", err)
}
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
// Create endpoint and bind it.
wq := waiter.Queue{}
@@ -484,10 +515,20 @@ func TestTransportForwarding(t *testing.T) {
// Route all packets to address 3 to NIC 2 and all packets to address
// 1 to NIC 1.
- s.SetRouteTable([]tcpip.Route{
- {"\x03", "\xff", "\x00", 2},
- {"\x01", "\xff", "\x00", 1},
- })
+ {
+ subnet0, err := tcpip.NewSubnet("\x03", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet0, Gateway: "\x00", NIC: 2},
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ })
+ }
wq := waiter.Queue{}
ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 4208c0303..8f9b86cce 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -31,6 +31,7 @@ package tcpip
import (
"errors"
"fmt"
+ "math/bits"
"reflect"
"strconv"
"strings"
@@ -39,6 +40,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -144,8 +146,17 @@ type Address string
type AddressMask string
// String implements Stringer.
-func (a AddressMask) String() string {
- return Address(a).String()
+func (m AddressMask) String() string {
+ return Address(m).String()
+}
+
+// Prefix returns the number of bits before the first host bit (computed per byte, so exact for contiguous masks).
+func (m AddressMask) Prefix() int {
+ p := 0
+ for _, b := range []byte(m) {
+ p += bits.LeadingZeros8(^b)
+ }
+ return p
}
// Subnet is a subnet defined by its address and mask.
@@ -167,6 +178,11 @@ func NewSubnet(a Address, m AddressMask) (Subnet, error) {
return Subnet{a, m}, nil
}
+// String implements Stringer.
+func (s Subnet) String() string {
+ return fmt.Sprintf("%s/%d", s.ID(), s.Prefix())
+}
+
// Contains returns true iff the address is of the same length and matches the
// subnet address and mask.
func (s *Subnet) Contains(a Address) bool {
@@ -189,28 +205,13 @@ func (s *Subnet) ID() Address {
// Bits returns the number of ones (network bits) and zeros (host bits) in the
// subnet mask.
func (s *Subnet) Bits() (ones int, zeros int) {
- for _, b := range []byte(s.mask) {
- for i := uint(0); i < 8; i++ {
- if b&(1<<i) == 0 {
- zeros++
- } else {
- ones++
- }
- }
- }
- return
+ ones = s.mask.Prefix()
+ return ones, len(s.mask)*8 - ones
}
// Prefix returns the number of bits before the first host bit.
func (s *Subnet) Prefix() int {
- for i, b := range []byte(s.mask) {
- for j := 7; j >= 0; j-- {
- if b&(1<<uint(j)) == 0 {
- return i*8 + 7 - j
- }
- }
- }
- return len(s.mask) * 8
+ return s.mask.Prefix()
}
// Mask returns the subnet mask.
@@ -328,12 +329,12 @@ type Endpoint interface {
// ErrNoLinkAddress and a notification channel is returned for the caller to
// block. Channel is closed once address resolution is complete (success or
// not). The channel is only non-nil in this case.
- Write(Payload, WriteOptions) (uintptr, <-chan struct{}, *Error)
+ Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error)
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
- Peek([][]byte) (uintptr, ControlMessages, *Error)
+ Peek([][]byte) (int64, ControlMessages, *Error)
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
@@ -352,6 +353,9 @@ type Endpoint interface {
// ErrAddressFamilyNotSupported must be returned.
Connect(address FullAddress) *Error
+ // Disconnect disconnects the endpoint from its peer.
+ Disconnect() *Error
+
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
Shutdown(flags ShutdownFlags) *Error
@@ -403,6 +407,9 @@ type Endpoint interface {
//
// NOTE: This method is a no-op for sockets other than TCP.
ModerateRecvBuf(copied int)
+
+ // IPTables returns the iptables for this endpoint's stack.
+ IPTables() (iptables.IPTables, error)
}
// WriteOptions contains options for Endpoint.Write.
@@ -563,13 +570,8 @@ type BroadcastOption int
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
type Route struct {
- // Destination is the address that must be matched against the masked
- // target address to check if this row is viable.
- Destination Address
-
- // Mask specifies which bits of the Destination and the target address
- // must match for this row to be viable.
- Mask AddressMask
+ // Destination must contain the target address for this row to be viable.
+ Destination Subnet
// Gateway is the gateway to be used if this row is viable.
Gateway Address
@@ -578,25 +580,15 @@ type Route struct {
NIC NICID
}
-// Match determines if r is viable for the given destination address.
-func (r *Route) Match(addr Address) bool {
- if len(addr) != len(r.Destination) {
- return false
- }
-
- // Using header.Ipv4Broadcast would introduce an import cycle, so
- // we'll use a literal instead.
- if addr == "\xff\xff\xff\xff" {
- return true
- }
-
- for i := 0; i < len(r.Destination); i++ {
- if (addr[i] & r.Mask[i]) != r.Destination[i] {
- return false
- }
+// String implements the fmt.Stringer interface.
+func (r Route) String() string {
+ var out strings.Builder
+ fmt.Fprintf(&out, "%s", r.Destination)
+ if len(r.Gateway) > 0 {
+ fmt.Fprintf(&out, " via %s", r.Gateway)
}
-
- return true
+ fmt.Fprintf(&out, " nic %d", r.NIC)
+ return out.String()
}
// LinkEndpointID represents a data link layer endpoint.
@@ -1068,6 +1060,11 @@ type AddressWithPrefix struct {
PrefixLen int
}
+// String implements the fmt.Stringer interface.
+func (a AddressWithPrefix) String() string {
+ return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen)
+}
+
// ProtocolAddress is an address and the network protocol it is associated
// with.
type ProtocolAddress struct {
@@ -1078,11 +1075,13 @@ type ProtocolAddress struct {
AddressWithPrefix AddressWithPrefix
}
-// danglingEndpointsMu protects access to danglingEndpoints.
-var danglingEndpointsMu sync.Mutex
+var (
+ // danglingEndpointsMu protects access to danglingEndpoints.
+ danglingEndpointsMu sync.Mutex
-// danglingEndpoints tracks all dangling endpoints no longer owned by the app.
-var danglingEndpoints = make(map[Endpoint]struct{})
+ // danglingEndpoints tracks all dangling endpoints no longer owned by the app.
+ danglingEndpoints = make(map[Endpoint]struct{})
+)
// GetDanglingEndpoints returns all dangling endpoints.
func GetDanglingEndpoints() []Endpoint {
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index ebb1c1b56..fb3a0a5ee 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -60,12 +60,12 @@ func TestSubnetBits(t *testing.T) {
}{
{"\x00", 0, 8},
{"\x00\x00", 0, 16},
- {"\x36", 4, 4},
- {"\x5c", 4, 4},
- {"\x5c\x5c", 8, 8},
- {"\x5c\x36", 8, 8},
- {"\x36\x5c", 8, 8},
- {"\x36\x36", 8, 8},
+ {"\x36", 0, 8},
+ {"\x5c", 0, 8},
+ {"\x5c\x5c", 0, 16},
+ {"\x5c\x36", 0, 16},
+ {"\x36\x5c", 0, 16},
+ {"\x36\x36", 0, 16},
{"\xff", 8, 0},
{"\xff\xff", 16, 0},
}
@@ -122,26 +122,6 @@ func TestSubnetCreation(t *testing.T) {
}
}
-func TestRouteMatch(t *testing.T) {
- tests := []struct {
- d Address
- m AddressMask
- a Address
- want bool
- }{
- {"\xc2\x80", "\xff\xf0", "\xc2\x80", true},
- {"\xc2\x80", "\xff\xf0", "\xc2\x00", false},
- {"\xc2\x00", "\xff\xf0", "\xc2\x00", true},
- {"\xc2\x00", "\xff\xf0", "\xc2\x80", false},
- }
- for _, tt := range tests {
- r := Route{Destination: tt.d, Mask: tt.m}
- if got := r.Match(tt.a); got != tt.want {
- t.Errorf("Route(%v).Match(%v) = %v, want %v", r, tt.a, got, tt.want)
- }
- }
-}
-
func TestAddressString(t *testing.T) {
for _, want := range []string{
// Taken from stdlib.
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 62182a3e6..d78a162b8 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index ba6671c26..451d3880e 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -130,6 +131,11 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+// IPTables implements tcpip.Endpoint.IPTables.
+func (e *endpoint) IPTables() (iptables.IPTables, error) {
+ return e.stack.IPTables(), nil
+}
+
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -199,7 +205,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -301,11 +307,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
return 0, nil, err
}
- return uintptr(len(v)), nil, nil
+ return int64(len(v)), nil, nil
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -422,16 +428,16 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if addr.Addr == "" {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
nicid := addr.NIC
localPort := uint16(0)
switch e.state {
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index 99b8c4093..c587b96b6 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -63,7 +63,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- e.stack = stack.StackFromEnv
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
if e.state != stateBound && e.state != stateConnected {
return
@@ -73,7 +78,7 @@ func (e *endpoint) afterLoad() {
if e.state == stateConnected {
e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
if err != nil {
- panic(*err)
+ panic(err)
}
e.id.LocalAddress = e.route.LocalAddress
@@ -85,6 +90,6 @@ func (e *endpoint) afterLoad() {
e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
if err != nil {
- panic(*err)
+ panic(err)
}
}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index bc4b255b4..7241f6c19 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b633cd9d8..13e17e2a6 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -168,6 +169,11 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
+// IPTables implements tcpip.Endpoint.IPTables.
+func (ep *endpoint) IPTables() (iptables.IPTables, error) {
+ return ep.stack.IPTables(), nil
+}
+
// Read implements tcpip.Endpoint.Read.
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
if !ep.associated {
@@ -201,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
}
// Write implements tcpip.Endpoint.Write.
-func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -305,7 +311,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
// finishWrite writes the payload to a route. It resolves the route if
// necessary. It's really just a helper to make defer unnecessary in Write.
-func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) {
// We may need to resolve the route (match a link layer address to the
// network address). If that requires blocking (e.g. to use ARP),
// return a channel on which the caller can wait.
@@ -335,24 +341,24 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintpt
return 0, nil, tcpip.ErrUnknownProtocol
}
- return uintptr(len(payloadBytes)), nil, nil
+ return int64(len(payloadBytes)), nil, nil
}
// Peek implements tcpip.Endpoint.Peek.
-func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect implements tcpip.Endpoint.Connect.
func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if addr.Addr == "" {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
if ep.closed {
return tcpip.ErrInvalidEndpointState
}
@@ -484,7 +490,7 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return nil
+ return tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index cb5534d90..168953dec 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -63,19 +63,23 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) {
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
- // StackFromEnv is a stack used specifically for save/restore.
- ep.stack = stack.StackFromEnv
+ stack.StackFromEnv.RegisterRestoredEndpoint(ep)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (ep *endpoint) Resume(s *stack.Stack) {
+ ep.stack = s
- // If the endpoint is connected, re-connect via the save/restore stack.
+ // If the endpoint is connected, re-connect.
if ep.connected {
var err *tcpip.Error
ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
if err != nil {
- panic(*err)
+ panic(err)
}
}
- // If the endpoint is bound, re-bind via the save/restore stack.
+ // If the endpoint is bound, re-bind.
if ep.bound {
if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
panic(tcpip.ErrBadLocalAddress)
@@ -83,6 +87,6 @@ func (ep *endpoint) afterLoad() {
}
if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
- panic(*err)
+ panic(err)
}
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 4cd25e8e2..1ee1a53f8 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -48,6 +48,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 52fd1bfa3..e9c5099ea 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -96,6 +96,17 @@ type listenContext struct {
hasher hash.Hash
v6only bool
netProto tcpip.NetworkProtocolNumber
+ // pendingMu protects pendingEndpoints. This should only be accessed
+ // by the listening endpoint's worker goroutine.
+ //
+ // Lock Ordering: listenEP.workerMu -> pendingMu
+ pendingMu sync.Mutex
+ // pending is used to wait for all pendingEndpoints to finish when
+ // a socket is closed.
+ pending sync.WaitGroup
+ // pendingEndpoints is a map of all endpoints for which a handshake is
+ // in progress.
+ pendingEndpoints map[stack.TransportEndpointID]*endpoint
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@@ -133,14 +144,15 @@ func decSynRcvdCount() {
}
// newListenContext creates a new listen context.
-func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
- stack: stack,
- rcvWnd: rcvWnd,
- hasher: sha1.New(),
- v6only: v6only,
- netProto: netProto,
- listenEP: listenEP,
+ stack: stk,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6only: v6only,
+ netProto: netProto,
+ listenEP: listenEP,
+ pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
rand.Read(l.nonce[0][:])
@@ -253,6 +265,17 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, err
}
+ // listenEP is nil when listenContext is used by tcp.Forwarder.
+ if l.listenEP != nil {
+ l.listenEP.mu.Lock()
+ if l.listenEP.state != StateListen {
+ l.listenEP.mu.Unlock()
+ return nil, tcpip.ErrConnectionAborted
+ }
+ l.addPendingEndpoint(ep)
+ l.listenEP.mu.Unlock()
+ }
+
// Perform the 3-way handshake.
h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
@@ -260,6 +283,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if err := h.execute(); err != nil {
ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.Close()
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ }
return nil, err
}
ep.mu.Lock()
@@ -274,15 +300,41 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return ep, nil
}
+func (l *listenContext) addPendingEndpoint(n *endpoint) {
+ l.pendingMu.Lock()
+ l.pendingEndpoints[n.id] = n
+ l.pending.Add(1)
+ l.pendingMu.Unlock()
+}
+
+func (l *listenContext) removePendingEndpoint(n *endpoint) {
+ l.pendingMu.Lock()
+ delete(l.pendingEndpoints, n.id)
+ l.pending.Done()
+ l.pendingMu.Unlock()
+}
+
+func (l *listenContext) closeAllPendingEndpoints() {
+ l.pendingMu.Lock()
+ for _, n := range l.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
+ }
+ l.pendingMu.Unlock()
+ l.pending.Wait()
+}
+
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state, the new endpoint is closed
// instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
- e.mu.RLock()
+ e.mu.Lock()
state := e.state
- e.mu.RUnlock()
+ e.pendingAccepted.Add(1)
+ defer e.pendingAccepted.Done()
+ acceptedChan := e.acceptedChan
+ e.mu.Unlock()
if state == StateListen {
- e.acceptedChan <- n
+ acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
} else {
n.Close()
@@ -304,7 +356,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
return
}
-
+ ctx.removePendingEndpoint(n)
e.deliverAccepted(n)
}
@@ -451,6 +503,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+ e.mu.Lock()
+ v6only := e.v6only
+ e.mu.Unlock()
+ ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
+
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
// handleSynSegment() from attempting to queue new connections
@@ -458,6 +515,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
e.state = StateClose
+ // Close any endpoints in SYN-RCVD state.
+ ctx.closeAllPendingEndpoints()
+
// Do cleanup if needed.
e.completeWorkerLocked()
@@ -470,12 +530,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
}()
- e.mu.Lock()
- v6only := e.v6only
- e.mu.Unlock()
-
- ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
-
s := sleep.Sleeper{}
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
@@ -492,7 +546,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.handleListenSegment(ctx, s)
s.decRef()
}
- synRcvdCount.pending.Wait()
close(e.drainDone)
<-e.undrain
}
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index d9f79e8c5..c54610a87 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -570,3 +570,89 @@ func TestV4AcceptOnV4(t *testing.T) {
// Test acceptance.
testV4Accept(t, c)
}
+
+func testV4ListenClose(t *testing.T, c *context.Context) {
+ // Set the SynRcvd threshold to zero to force a syn cookie based accept
+ // to happen.
+ saved := tcp.SynRcvdCountThreshold
+ defer func() {
+ tcp.SynRcvdCountThreshold = saved
+ }()
+ tcp.SynRcvdCountThreshold = 0
+ const n = uint16(32)
+
+ // Start listening.
+ if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ irs := seqnum.Value(789)
+ for i := uint16(0); i < n; i++ {
+ // Send a SYN request.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + i,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Each of these ACKs will cause a syn-cookie based connection to be
+ // accepted and delivered to the listening endpoint.
+ for i := uint16(0); i < n; i++ {
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcp.SequenceNumber())
+ // Send ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcp.DestinationPort(),
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+ nep, _, err := c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ nep, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(10 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+ nep.Close()
+ c.EP.Close()
+}
+
+func TestV4ListenCloseOnV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test closing the listening endpoint with connections still pending.
+ testV4ListenClose(t, c)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index cc49c8272..ac927569a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tmutex"
@@ -361,6 +362,12 @@ type endpoint struct {
// without hearing a response, the connection is closed.
keepalive keepalive
+ // pendingAccepted is a synchronization primitive used to track number
+ // of connections that are queued up to be delivered to the accepted
+ // channel. We use this to ensure that all goroutines blocked on writing
+ // to the acceptedChan below terminate before we close acceptedChan.
+ pendingAccepted sync.WaitGroup `state:"nosave"`
+
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
@@ -374,7 +381,11 @@ type endpoint struct {
// The goroutine drain completion notification channel.
drainDone chan struct{} `state:"nosave"`
- // The goroutine undrain notification channel.
+ // The goroutine undrain notification channel. This is currently used as
+ // a way to block the worker goroutines. Today nothing closes/writes
+ // this channel and this causes any goroutines waiting on this to just
+ // block. This is used during save/restore to prevent worker goroutines
+ // from mutating state as it's being saved.
undrain chan struct{} `state:"nosave"`
// probe if not nil is invoked on every received segment. It is passed
@@ -574,6 +585,34 @@ func (e *endpoint) Close() {
e.mu.Unlock()
}
+// closePendingAcceptableConnectionsLocked closes all connections that have completed
+// handshake but not yet been delivered to the application.
+func (e *endpoint) closePendingAcceptableConnectionsLocked() {
+ done := make(chan struct{})
+ // Spin a goroutine up as ranging on e.acceptedChan will just block when
+ // there are no more connections in the channel. Using a non-blocking
+ // select does not work as it can potentially select the default case
+ // even when there are pending writes that have not yet been written to
+ // the channel.
+ go func() {
+ defer close(done)
+ for n := range e.acceptedChan {
+ n.mu.Lock()
+ n.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ n.mu.Unlock()
+ n.Close()
+ }
+ }()
+ // pendingAccepted (see endpoint.deliverAccepted) tracks the number of
+ // endpoints which have completed handshake but are not yet written to
+ // the e.acceptedChan. We wait here till the goroutine above can drain
+ // all such connections from e.acceptedChan.
+ e.pendingAccepted.Wait()
+ close(e.acceptedChan)
+ <-done
+ e.acceptedChan = nil
+}
+
// cleanupLocked frees all resources associated with the endpoint. It is called
// after Close() is called and the worker goroutine (if any) is done with its
// work.
@@ -581,14 +620,7 @@ func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
if e.acceptedChan != nil {
- close(e.acceptedChan)
- for n := range e.acceptedChan {
- n.mu.Lock()
- n.resetConnectionLocked(tcpip.ErrConnectionAborted)
- n.mu.Unlock()
- n.Close()
- }
- e.acceptedChan = nil
+ e.closePendingAcceptableConnectionsLocked()
}
e.workerCleanup = false
@@ -683,6 +715,11 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvListMu.Unlock()
}
+// IPTables implements tcpip.Endpoint.IPTables.
+func (e *endpoint) IPTables() (iptables.IPTables, error) {
+ return e.stack.IPTables(), nil
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
@@ -740,60 +777,95 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
return v, nil
}
-// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
- // Linux completely ignores any address passed to sendto(2) for TCP sockets
- // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
- // and opts.EndOfRecord are also ignored.
-
- e.mu.RLock()
- defer e.mu.RUnlock()
-
+// isEndpointWritableLocked checks if a given endpoint is writable
+// and also returns the number of bytes that can be written at this
+// moment. If the endpoint is not writable then it returns an error
+// indicating the reason why it's not writable.
+// Caller must hold e.mu and e.sndBufMu.
+func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
if !e.state.connected() {
switch e.state {
case StateError:
- return 0, nil, e.hardError
+ return 0, e.hardError
default:
- return 0, nil, tcpip.ErrClosedForSend
+ return 0, tcpip.ErrClosedForSend
}
}
- // Nothing to do if the buffer is empty.
- if p.Size() == 0 {
- return 0, nil, nil
- }
-
- e.sndBufMu.Lock()
-
// Check if the connection has already been closed for sends.
if e.sndClosed {
- e.sndBufMu.Unlock()
- return 0, nil, tcpip.ErrClosedForSend
+ return 0, tcpip.ErrClosedForSend
}
- // Check against the limit.
avail := e.sndBufSize - e.sndBufUsed
if avail <= 0 {
+ return 0, tcpip.ErrWouldBlock
+ }
+ return avail, nil
+}
+
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
+
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
e.sndBufMu.Unlock()
- return 0, nil, tcpip.ErrWouldBlock
+ e.mu.RUnlock()
+ return 0, nil, err
}
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+
+ // Nothing to do if the buffer is empty.
+ if p.Size() == 0 {
+ return 0, nil, nil
+ }
+
+ // Copy in memory without holding sndBufMu so that worker goroutine can
+ // make progress independent of this operation.
v, perr := p.Get(avail)
if perr != nil {
- e.sndBufMu.Unlock()
return 0, nil, perr
}
- l := len(v)
- s := newSegmentFromView(&e.route, e.id, v)
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a
+ // write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due to a
+ // simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
// Add data to the send queue.
+ l := len(v)
+ s := newSegmentFromView(&e.route, e.id, v)
e.sndBufUsed += l
e.sndBufInQueue += seqnum.Size(l)
e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
if e.workMu.TryLock() {
// Do the work inline.
@@ -803,13 +875,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
- return uintptr(l), nil, nil
+ return int64(l), nil, nil
}
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -835,8 +907,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er
// Make a copy of vec so we can modify the slide headers.
vec = append([][]byte(nil), vec...)
- var num uintptr
-
+ var num int64
for s := e.rcvList.Front(); s != nil; s = s.Next() {
views := s.data.Views()
@@ -855,7 +926,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er
n := copy(vec[0], v)
v = v[n:]
vec[0] = vec[0][n:]
- num += uintptr(n)
+ num += int64(n)
}
}
}
@@ -1277,7 +1348,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
netProto = header.IPv4ProtocolNumber
addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == "\x00\x00\x00\x00" {
+ if addr.Addr == header.IPv4Any {
addr.Addr = ""
}
}
@@ -1291,13 +1362,13 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
return netProto, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- if addr.Addr == "" && addr.Port == 0 {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
return e.connect(addr, true, true)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index b3f0f6c5d..831389ec7 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -165,7 +165,12 @@ func (e *endpoint) loadState(state EndpointState) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- e.stack = stack.StackFromEnv
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.workMu.Init()
@@ -197,14 +202,13 @@ func (e *endpoint) afterLoad() {
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
bind()
if len(e.connectingAddress) == 0 {
+ e.connectingAddress = e.id.RemoteAddress
// This endpoint is accepted by netstack but not yet by
// the app. If the endpoint is IPv6 but the remote
// address is IPv4, we need to connect as IPv6 so that
// dual-stack mode can be properly activated.
if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
- } else {
- e.connectingAddress = e.id.RemoteAddress
}
}
// Reset the scoreboard to reinitialize the sack information as
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 0fee7ab72..735edfe55 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -39,6 +39,28 @@ const (
nDupAckThreshold = 3
)
+// ccState indicates the current congestion control state for this sender.
+type ccState int
+
+const (
+ // Open indicates that the sender is receiving acks in order and
+ // no loss, dupACKs, etc. have been detected.
+ Open ccState = iota
+ // RTORecovery indicates that an RTO has occurred and the sender
+ // has entered an RTO based recovery phase.
+ RTORecovery
+ // FastRecovery indicates that the sender has entered FastRecovery
+ // based on receiving nDupAck's. This state is entered only when
+ // SACK is not in use.
+ FastRecovery
+ // SACKRecovery indicates that the sender has entered SACK based
+ // recovery.
+ SACKRecovery
+ // Disorder indicates the sender either received some SACK blocks
+ // or dupACKs.
+ Disorder
+)
+
// congestionControl is an interface that must be implemented by any supported
// congestion control algorithm.
type congestionControl interface {
@@ -138,6 +160,9 @@ type sender struct {
// maxSentAck is the maxium acknowledgement actually sent.
maxSentAck seqnum.Value
+ // state is the current state of congestion control for this endpoint.
+ state ccState
+
// cc is the congestion control algorithm in use for this sender.
cc congestionControl
}
@@ -435,6 +460,7 @@ func (s *sender) retransmitTimerExpired() bool {
s.leaveFastRecovery()
}
+ s.state = RTORecovery
s.cc.HandleRTOExpired()
// Mark the next segment to be sent as the first unacknowledged one and
@@ -638,7 +664,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
segEnd = seg.sequenceNumber.Add(1)
// Transition to FIN-WAIT1 state since we're initiating an active close.
s.ep.mu.Lock()
- s.ep.state = StateFinWait1
+ switch s.ep.state {
+ case StateCloseWait:
+ // We've already received a FIN and are now sending our own. The
+ // sender is now awaiting a final ACK for this FIN.
+ s.ep.state = StateLastAck
+ default:
+ s.ep.state = StateFinWait1
+ }
s.ep.mu.Unlock()
} else {
// We're sending a non-FIN segment.
@@ -820,9 +853,11 @@ func (s *sender) enterFastRecovery() {
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
if s.ep.sackPermitted {
+ s.state = SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
return
}
+ s.state = FastRecovery
s.ep.stack.Stats().TCP.FastRecovery.Increment()
}
@@ -981,6 +1016,7 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
s.fr.highRxt = s.sndUna - 1
// Do run SetPipe() to calculate the outstanding segments.
s.SetPipe()
+ s.state = Disorder
return false
}
@@ -1112,6 +1148,9 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// window based on the number of acknowledged packets.
if !s.fr.active {
s.cc.Update(originalOutstanding - s.outstanding)
+ if s.fr.last.LessThan(s.sndUna) {
+ s.state = Open
+ }
}
// It is possible for s.outstanding to drop below zero if we get
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 915a98047..f79b8ec5f 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -2874,15 +2874,11 @@ func makeStack() (*stack.Stack, *tcpip.Error) {
s.SetRouteTable([]tcpip.Route{
{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
},
{
- Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv6EmptySubnet,
NIC: 1,
},
})
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index bcc0f3e28..272481aa0 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -168,15 +168,11 @@ func New(t *testing.T, mtu uint32) *Context {
s.SetRouteTable([]tcpip.Route{
{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
},
{
- Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv6EmptySubnet,
NIC: 1,
},
})
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 6dac66b50..ac2666f69 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/waiter",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 91f89a781..ac5905772 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -172,6 +173,11 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+// IPTables implements tcpip.Endpoint.IPTables.
+func (e *endpoint) IPTables() (iptables.IPTables, error) {
+ return e.stack.IPTables(), nil
+}
+
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -241,13 +247,13 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stack.Route, tcpip.NICID, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto, err := e.checkV4Mapped(&addr, false)
- if err != nil {
- return stack.Route{}, 0, 0, err
+func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+ localAddr := e.id.LocalAddress
+ if isBroadcastOrMulticast(localAddr) {
+ // A packet can only originate from a unicast address (i.e., an interface).
+ localAddr = ""
}
- localAddr := e.id.LocalAddress
if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
if nicid == 0 {
nicid = e.multicastNICID
@@ -260,14 +266,14 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stac
// Find a route to the desired destination.
r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop)
if err != nil {
- return stack.Route{}, 0, 0, err
+ return stack.Route{}, 0, err
}
- return r, nicid, netProto, nil
+ return r, nicid, nil
}
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -336,7 +342,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
return 0, nil, tcpip.ErrBroadcastDisabled
}
- r, _, _, err := e.connectRoute(nicid, *to)
+ netProto, err := e.checkV4Mapped(to, false)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ r, _, err := e.connectRoute(nicid, *to, netProto)
if err != nil {
return 0, nil, err
}
@@ -368,11 +379,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil {
return 0, nil, err
}
- return uintptr(len(v)), nil, nil
+ return int64(len(v)), nil, nil
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -442,7 +453,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
nicID := v.NIC
- if v.InterfaceAddr == header.IPv4Any {
+
+ // The interface address is considered not-set if it is empty or contains
+ // all-zeros. The former represents the zero value in Go, the latter the
+ // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall.
+ allZeros := header.IPv4Any
+ if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros {
if nicID == 0 {
r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
if err == nil {
@@ -686,7 +702,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
netProto = header.IPv4ProtocolNumber
addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == "\x00\x00\x00\x00" {
+ if addr.Addr == header.IPv4Any {
addr.Addr = ""
}
@@ -705,7 +721,8 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
-func (e *endpoint) disconnect() *tcpip.Error {
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -740,8 +757,9 @@ func (e *endpoint) disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- if addr.Addr == "" {
- return e.disconnect()
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
}
if addr.Port == 0 {
// We don't support connecting to port zero.
@@ -770,7 +788,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- r, nicid, netProto, err := e.connectRoute(nicid, addr)
+ r, nicid, err := e.connectRoute(nicid, addr, netProto)
if err != nil {
return err
}
@@ -906,8 +924,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
}
nicid := addr.NIC
- if len(addr.Addr) != 0 {
- // A local address was specified, verify that it's valid.
+ if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
+ // A local unicast address was specified, verify that it's valid.
nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nicid == 0 {
return tcpip.ErrBadLocalAddress
@@ -1056,3 +1074,7 @@ func (e *endpoint) State() uint32 {
// TODO(b/112063468): Translate internal state to values returned by Linux.
return 0
}
+
+func isBroadcastOrMulticast(a tcpip.Address) bool {
+ return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 18e786397..5cbb56120 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -64,7 +64,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- e.stack = stack.StackFromEnv
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
for _, m := range e.multicastMemberships {
if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
@@ -90,9 +95,10 @@ func (e *endpoint) afterLoad() {
if e.state == stateConnected {
e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
if err != nil {
- panic(*err)
+ panic(err)
}
- } else if len(e.id.LocalAddress) != 0 { // stateBound
+ } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound
+ // A local unicast address is specified, verify that it's valid.
if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
@@ -105,6 +111,6 @@ func (e *endpoint) afterLoad() {
e.id.LocalPort = 0
e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
if err != nil {
- panic(*err)
+ panic(err)
}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 56c285f88..9da6edce2 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,6 +16,7 @@ package udp_test
import (
"bytes"
+ "fmt"
"math"
"math/rand"
"testing"
@@ -34,13 +35,19 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// Addresses and ports used for testing. It is recommended that tests stick to
+// using these addresses as it allows using the testFlow helper.
+// Naming rules: 'stack*' denotes local addresses and ports, while 'test*'
+// represents the remote endpoint.
const (
+ v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
- testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
- multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr
- V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+ stackV4MappedAddr = v4MappedAddrPrefix + stackAddr
+ testV4MappedAddr = v4MappedAddrPrefix + testAddr
+ multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
+ broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr
+ v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00"
stackAddr = "\x0a\x00\x00\x01"
stackPort = 1234
@@ -48,7 +55,7 @@ const (
testPort = 4096
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- multicastPort = 1234
+ broadcastAddr = header.IPv4Broadcast
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -56,6 +63,205 @@ const (
defaultMTU = 65536
)
+// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in
+// a packet header. These values are used to populate a header or verify one.
+// Note that because they are used in packet headers, the addresses are never in
+// a V4-mapped format.
+type header4Tuple struct {
+ srcAddr tcpip.FullAddress
+ dstAddr tcpip.FullAddress
+}
+
+// testFlow implements a helper type used for sending and receiving test
+// packets. A given test flow value defines 1) the socket endpoint used for the
+// test and 2) the type of packet sent or received on the endpoint. E.g., a
+// multicastV6Only flow is a V6 multicast packet passing through a V6-only
+// endpoint. The type provides helper methods to characterize the flow (e.g.,
+// isV4) as well as return a proper header4Tuple for it.
+type testFlow int
+
+const (
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+)
+
+func (flow testFlow) String() string {
+ switch flow {
+ case unicastV4:
+ return "unicastV4"
+ case unicastV6:
+ return "unicastV6"
+ case unicastV6Only:
+ return "unicastV6Only"
+ case unicastV4in6:
+ return "unicastV4in6"
+ case multicastV4:
+ return "multicastV4"
+ case multicastV6:
+ return "multicastV6"
+ case multicastV6Only:
+ return "multicastV6Only"
+ case multicastV4in6:
+ return "multicastV4in6"
+ case broadcast:
+ return "broadcast"
+ case broadcastIn6:
+ return "broadcastIn6"
+ default:
+ return "unknown"
+ }
+}
+
+// packetDirection explains if a flow is incoming (read) or outgoing (write).
+type packetDirection int
+
+const (
+ incoming packetDirection = iota
+ outgoing
+)
+
+// header4Tuple returns the header4Tuple for the given flow and direction. Note
+// that the tuple contains no mapped addresses as those only exist at the socket
+// level but not at the packet header level.
+func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
+ var h header4Tuple
+ if flow.isV4() {
+ if d == outgoing {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
+ dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
+ }
+ } else {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
+ dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
+ }
+ }
+ if flow.isMulticast() {
+ h.dstAddr.Addr = multicastAddr
+ } else if flow.isBroadcast() {
+ h.dstAddr.Addr = broadcastAddr
+ }
+ } else { // IPv6
+ if d == outgoing {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
+ dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ }
+ } else {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
+ }
+ }
+ if flow.isMulticast() {
+ h.dstAddr.Addr = multicastV6Addr
+ }
+ }
+ return h
+}
+
+func (flow testFlow) getMcastAddr() tcpip.Address {
+ if flow.isV4() {
+ return multicastAddr
+ }
+ return multicastV6Addr
+}
+
+// mapAddrIfApplicable converts the given V4 address into its V4-mapped version
+// if it is applicable to the flow.
+func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address {
+ if flow.isMapped() {
+ return v4MappedAddrPrefix + v4Addr
+ }
+ return v4Addr
+}
+
+// netProto returns the protocol number used for the network packet.
+func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
+ if flow.isV4() {
+ return ipv4.ProtocolNumber
+ }
+ return ipv6.ProtocolNumber
+}
+
+// sockProto returns the protocol number used when creating the socket
+// endpoint for this flow.
+func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
+ switch flow {
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ return ipv6.ProtocolNumber
+ case unicastV4, multicastV4, broadcast:
+ return ipv4.ProtocolNumber
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) {
+ if flow.isV4() {
+ return checker.IPv4
+ }
+ return checker.IPv6
+}
+
+func (flow testFlow) isV6() bool { return !flow.isV4() }
+func (flow testFlow) isV4() bool {
+ return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped()
+}
+
+func (flow testFlow) isV6Only() bool {
+ switch flow {
+ case unicastV6Only, multicastV6Only:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isMulticast() bool {
+ switch flow {
+ case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isBroadcast() bool {
+ switch flow {
+ case broadcast, broadcastIn6:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isMapped() bool {
+ switch flow {
+ case unicastV4in6, multicastV4in6, broadcastIn6:
+ return true
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
@@ -65,12 +271,9 @@ type testContext struct {
wq waiter.Queue
}
-type headers struct {
- srcPort uint16
- dstPort uint16
-}
-
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
+ t.Helper()
+
s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
id, linkEP := channel.New(256, mtu, "")
@@ -91,15 +294,11 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
s.SetRouteTable([]tcpip.Route{
{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
},
{
- Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv6EmptySubnet,
NIC: 1,
},
})
@@ -117,51 +316,54 @@ func (c *testContext) cleanup() {
}
}
-func (c *testContext) createV6Endpoint(v6only bool) {
+func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
+ c.t.Helper()
+
var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
+ c.t.Fatal("NewEndpoint failed: ", err)
}
+}
- var v tcpip.V6OnlyOption
- if v6only {
- v = 1
- }
- if err := c.ep.SetSockOpt(v); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
+func (c *testContext) createEndpointForFlow(flow testFlow) {
+ c.t.Helper()
+
+ c.createEndpoint(flow.sockProto())
+ if flow.isV6Only() {
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ } else if flow.isBroadcast() {
+ if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil {
+ c.t.Fatal("SetSockOpt failed:", err)
+ }
}
}
-func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte {
+// getPacketAndVerify reads a packet from the link endpoint and verifies the
+// header against expected values from the given test flow. In addition, it
+// calls any extra checker functions provided.
+func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
+ c.t.Helper()
+
select {
case p := <-c.linkEP.C:
- if p.Proto != protocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber)
+ if p.Proto != flow.netProto() {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
}
b := make([]byte, len(p.Header)+len(p.Payload))
copy(b, p.Header)
copy(b[len(p.Header):], p.Payload)
- var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker)
- var srcAddr, dstAddr tcpip.Address
- switch protocolNumber {
- case ipv4.ProtocolNumber:
- checkerFn = checker.IPv4
- srcAddr, dstAddr = stackAddr, testAddr
- if multicast {
- dstAddr = multicastAddr
- }
- case ipv6.ProtocolNumber:
- checkerFn = checker.IPv6
- srcAddr, dstAddr = stackV6Addr, testV6Addr
- if multicast {
- dstAddr = multicastV6Addr
- }
- default:
- c.t.Fatalf("unknown protocol %d", protocolNumber)
- }
- checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr))
+ h := flow.header4Tuple(outgoing)
+ checkers := append(
+ checkers,
+ checker.SrcAddr(h.srcAddr.Addr),
+ checker.DstAddr(h.dstAddr.Addr),
+ checker.UDP(checker.DstPort(h.dstAddr.Port)),
+ )
+ flow.checkerFn()(c.t, b, checkers...)
return b
case <-time.After(2 * time.Second):
@@ -171,7 +373,22 @@ func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, mult
return nil
}
-func (c *testContext) sendV6Packet(payload []byte, h *headers) {
+// injectPacket creates a packet of the given flow and with the given payload,
+// and injects it into the link endpoint.
+func (c *testContext) injectPacket(flow testFlow, payload []byte) {
+ c.t.Helper()
+
+ h := flow.header4Tuple(incoming)
+ if flow.isV4() {
+ c.injectV4Packet(payload, &h)
+ } else {
+ c.injectV6Packet(payload, &h)
+ }
+}
+
+// injectV6Packet creates a V6 test packet with the given payload and header
+// values, and injects it into the link endpoint.
+func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -182,20 +399,20 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) {
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: testV6Addr,
- DstAddr: stackV6Addr,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
u.Encode(&header.UDPFields{
- SrcPort: h.srcPort,
- DstPort: h.dstPort,
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -205,7 +422,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) {
c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
}
-func (c *testContext) sendPacket(payload []byte, h *headers) {
+// injectV4Packet creates a V4 test packet with the given payload and header
+// values, and injects it into the link endpoint.
+func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -217,21 +436,21 @@ func (c *testContext) sendPacket(payload []byte, h *headers) {
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
- SrcAddr: testAddr,
- DstAddr: stackAddr,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
ip.SetChecksum(^ip.CalculateChecksum())
// Initialize the UDP header.
u := header.UDP(buf[header.IPv4MinimumSize:])
u.Encode(&header.UDPFields{
- SrcPort: h.srcPort,
- DstPort: h.dstPort,
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testAddr, stackAddr, uint16(len(u)))
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -253,7 +472,7 @@ func TestBindPortReuse(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
var eps [5]tcpip.Endpoint
reusePortOpt := tcpip.ReusePortOption(1)
@@ -296,9 +515,9 @@ func TestBindPortReuse(t *testing.T) {
// Send a packet.
port := uint16(i % nports)
payload := newPayload()
- c.sendV6Packet(payload, &headers{
- srcPort: testPort + port,
- dstPort: stackPort,
+ c.injectV6Packet(payload, &header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port},
+ dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
})
var addr tcpip.FullAddress
@@ -333,13 +552,14 @@ func TestBindPortReuse(t *testing.T) {
}
}
-func testV4Read(c *testContext) {
- // Send a packet.
+// testRead sends a packet of the given test flow into the stack by injecting it
+// into the link endpoint. It then reads it from the UDP endpoint and verifies
+// its correctness.
+func testRead(c *testContext, flow testFlow) {
+ c.t.Helper()
+
payload := newPayload()
- c.sendPacket(payload, &headers{
- srcPort: testPort,
- dstPort: stackPort,
- })
+ c.injectPacket(flow, payload)
// Try to receive the data.
we, ch := waiter.NewChannelEntry(nil)
@@ -363,8 +583,9 @@ func testV4Read(c *testContext) {
}
// Check the peer address.
- if addr.Addr != testAddr {
- c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
+ h := flow.header4Tuple(incoming)
+ if addr.Addr != h.srcAddr.Addr {
+ c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr)
}
// Check the payload.
@@ -377,7 +598,7 @@ func TestBindEphemeralPort(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
t.Fatalf("ep.Bind(...) failed: %v", err)
@@ -388,7 +609,7 @@ func TestBindReservedPort(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
@@ -447,7 +668,7 @@ func TestV4ReadOnV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV4in6)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
@@ -455,29 +676,29 @@ func TestV4ReadOnV6(t *testing.T) {
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4in6)
}
func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV4in6)
// Bind to v4 mapped wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}); err != nil {
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4in6)
}
func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV4in6)
// Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
@@ -485,69 +706,29 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4in6)
}
func TestV6ReadOnV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV6)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- // Send a packet.
- payload := newPayload()
- c.sendV6Packet(payload, &headers{
- srcPort: testPort,
- dstPort: stackPort,
- })
-
- // Try to receive the data.
- we, ch := waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&we, waiter.EventIn)
- defer c.wq.EventUnregister(&we)
-
- var addr tcpip.FullAddress
- v, _, err := c.ep.Read(&addr)
- if err == tcpip.ErrWouldBlock {
- // Wait for data to become available.
- select {
- case <-ch:
- v, _, err = c.ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for data")
- }
- }
-
- // Check the peer address.
- if addr.Addr != testV6Addr {
- c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
- }
-
- // Check the payload.
- if !bytes.Equal(payload, v) {
- c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
- }
+ // Test acceptance.
+ testRead(c, unicastV6)
}
func TestV4ReadOnV4(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- // Create v4 UDP endpoint.
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
+ c.createEndpointForFlow(unicastV4)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
@@ -555,62 +736,123 @@ func TestV4ReadOnV4(t *testing.T) {
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4)
}
-func testV4Write(c *testContext) uint16 {
- // Write to V4 mapped address.
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
- })
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
+// address and receive data sent to that address.
+func TestReadOnBoundToMulticast(t *testing.T) {
+ // FIXME(b/128189410): multicastV4in6 currently doesn't work as
+ // AddMembershipOption doesn't handle V4in6 addresses.
+ for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to multicast address.
+ mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr())
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ // Join multicast group.
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatal("SetSockOpt failed:", err)
+ }
+
+ testRead(c, flow)
+ })
}
- if n != uintptr(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+}
+
+// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
+// address and receive broadcast data on it.
+func TestV4ReadOnBoundToBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to broadcast address.
+ bcastAddr := flow.mapAddrIfApplicable(broadcastAddr)
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Test acceptance.
+ testRead(c, flow)
+ })
}
+}
- // Check that we received the packet.
- b := c.getPacket(ipv4.ProtocolNumber, false)
- udp := header.UDP(header.IPv4(b).Payload())
- checker.IPv4(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
+// testFailingWrite sends a packet of the given test flow into the UDP endpoint
+// and verifies it fails with the provided error code.
+func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
+ c.t.Helper()
- // Check the payload.
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ h := flow.header4Tuple(outgoing)
+ writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
+
+ payload := buffer.View(newPayload())
+ _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
+ })
+ if gotErr != wantErr {
+ c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
}
+}
- return udp.SourcePort()
+// testWrite sends a packet of the given test flow from the UDP endpoint to the
+// flow's destination address:port. It then receives it from the link endpoint
+// and verifies its correctness including any additional checker functions
+// provided.
+func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ return testWriteInternal(c, flow, true, checkers...)
}
-func testV6Write(c *testContext) uint16 {
- // Write to v6 address.
+// testWriteWithoutDestination sends a packet of the given test flow from the
+// UDP endpoint without giving a destination address:port. It then receives it
+// from the link endpoint and verifies its correctness including any additional
+// checker functions provided.
+func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ return testWriteInternal(c, flow, false, checkers...)
+}
+
+func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+
+ writeOpts := tcpip.WriteOptions{}
+ if setDest {
+ h := flow.header4Tuple(outgoing)
+ writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
+ writeOpts = tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
+ }
+ }
payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- })
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
- if n != uintptr(len(payload)) {
+ if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
- // Check that we received the packet.
- b := c.getPacket(ipv6.ProtocolNumber, false)
- udp := header.UDP(header.IPv6(b).Payload())
- checker.IPv6(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
-
- // Check the payload.
+ // Receive the packet and check the payload.
+ b := c.getPacketAndVerify(flow, checkers...)
+ var udp header.UDP
+ if flow.isV4() {
+ udp = header.UDP(header.IPv4(b).Payload())
+ } else {
+ udp = header.UDP(header.IPv6(b).Payload())
+ }
if !bytes.Equal(payload, udp.Payload()) {
c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
}
@@ -619,8 +861,10 @@ func testV6Write(c *testContext) uint16 {
}
func testDualWrite(c *testContext) uint16 {
- v4Port := testV4Write(c)
- v6Port := testV6Write(c)
+ c.t.Helper()
+
+ v4Port := testWrite(c, unicastV4in6)
+ v6Port := testWrite(c, unicastV6)
if v4Port != v6Port {
c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
}
@@ -632,7 +876,7 @@ func TestDualWriteUnbound(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
testDualWrite(c)
}
@@ -641,7 +885,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
@@ -658,69 +902,51 @@ func TestDualWriteConnectedToV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- testV6Write(c)
+ testWrite(c, unicastV6)
// Write to V4 mapped address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
- })
- if err != tcpip.ErrNetworkUnreachable {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable)
- }
+ testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
}
func TestDualWriteConnectedToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- testV4Write(c)
+ testWrite(c, unicastV4in6)
// Write to v6 address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- })
- if err != tcpip.ErrInvalidEndpointState {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState)
- }
+ testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
}
func TestV4WriteOnV6Only(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(true)
+ c.createEndpointForFlow(unicastV6Only)
// Write to V4 mapped address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
- })
- if err != tcpip.ErrNoRoute {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
- }
+ testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute)
}
func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Bind to v4 mapped address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
@@ -728,84 +954,154 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
}
// Write to v6 address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- })
- if err != tcpip.ErrInvalidEndpointState {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState)
- }
+ testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
}
func TestV6WriteOnConnected(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
}
- // Write without destination.
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
- }
- if n != uintptr(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
- }
-
- // Check that we received the packet.
- b := c.getPacket(ipv6.ProtocolNumber, false)
- udp := header.UDP(header.IPv6(b).Payload())
- checker.IPv6(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
-
- // Check the payload.
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
- }
+ testWriteWithoutDestination(c, unicastV6)
}
func TestV4WriteOnConnected(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
}
- // Write without destination.
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+ testWriteWithoutDestination(c, unicastV4)
+}
+
+// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
+// that is bound to a V4 multicast address.
+func TestWriteOnBoundToV4Multicast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ testWrite(c, flow)
+ })
}
- if n != uintptr(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+}
+
+// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
+// socket that is bound to a V4-mapped multicast address.
+func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4Mapped mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
}
+}
- // Check that we received the packet.
- b := c.getPacket(ipv4.ProtocolNumber, false)
- udp := header.UDP(header.IPv4(b).Payload())
- checker.IPv4(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
+// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
+// socket that is bound to a V6 multicast address.
+func TestWriteOnBoundToV6Multicast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV6, multicastV6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- // Check the payload.
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ c.createEndpointForFlow(flow)
+
+ // Bind to V6 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToV6OnlyMulticast checks that we can send packets out of a
+// V6-only socket that is bound to a V6 multicast address.
+func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV6Only, multicastV6Only} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V6 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToBroadcast checks that we can send packets out of a
+// socket that is bound to the broadcast address.
+func TestWriteOnBoundToBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4 broadcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
+// socket that is bound to the V4-mapped broadcast address.
+func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4Mapped broadcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
}
}
@@ -814,18 +1110,14 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
defer c.cleanup()
// Create IPv4 UDP endpoint
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
+ c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- testV4Read(c)
+ testRead(c, unicastV4)
var want uint64 = 1
if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
@@ -837,7 +1129,7 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
testDualWrite(c)
@@ -847,244 +1139,102 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
}
}
-func setSockOptVariants(t *testing.T, optFunc func(*testing.T, string, tcpip.NetworkProtocolNumber, string)) {
- for _, name := range []string{"v4", "v6", "dual"} {
- t.Run(name, func(t *testing.T) {
- var networkProtocolNumber tcpip.NetworkProtocolNumber
- switch name {
- case "v4":
- networkProtocolNumber = ipv4.ProtocolNumber
- case "v6", "dual":
- networkProtocolNumber = ipv6.ProtocolNumber
- default:
- t.Fatal("unknown test variant")
- }
+func TestTTL(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- var variants []string
- switch name {
- case "v4":
- variants = []string{"v4"}
- case "v6":
- variants = []string{"v6"}
- case "dual":
- variants = []string{"v6", "mapped"}
- }
+ c.createEndpointForFlow(flow)
- for _, variant := range variants {
- t.Run(variant, func(t *testing.T) {
- optFunc(t, name, networkProtocolNumber, variant)
- })
+ const multicastTTL = 42
+ if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
}
- })
- }
-}
-func TestTTL(t *testing.T) {
- payload := tcpip.SlicePayload(buffer.View(newPayload()))
-
- setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) {
- for _, typ := range []string{"unicast", "multicast"} {
- t.Run(typ, func(t *testing.T) {
- var addr tcpip.Address
- var port uint16
- switch typ {
- case "unicast":
- port = testPort
- switch variant {
- case "v4":
- addr = testAddr
- case "mapped":
- addr = testV4MappedAddr
- case "v6":
- addr = testV6Addr
- default:
- t.Fatal("unknown test variant")
- }
- case "multicast":
- port = multicastPort
- switch variant {
- case "v4":
- addr = multicastAddr
- case "mapped":
- addr = multicastV4MappedAddr
- case "v6":
- addr = multicastV6Addr
- default:
- t.Fatal("unknown test variant")
- }
- default:
- t.Fatal("unknown test variant")
+ var wantTTL uint8
+ if flow.isMulticast() {
+ wantTTL = multicastTTL
+ } else {
+ var p stack.NetworkProtocol
+ if flow.isV4() {
+ p = ipv4.NewProtocol()
+ } else {
+ p = ipv6.NewProtocol()
}
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- switch name {
- case "v4":
- case "v6":
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
- case "dual":
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
- default:
- t.Fatal("unknown test variant")
- }
-
- const multicastTTL = 42
- if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ t.Fatal(err)
}
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ }
- n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
- }
- if n != uintptr(len(payload)) {
- c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload))
- }
+ testWrite(c, flow, checker.TTL(wantTTL))
+ })
+ }
+}
- checkerFn := checker.IPv4
- switch variant {
- case "v4", "mapped":
- case "v6":
- checkerFn = checker.IPv6
- default:
- t.Fatal("unknown test variant")
- }
- var wantTTL uint8
- var multicast bool
- switch typ {
- case "unicast":
- multicast = false
- switch variant {
- case "v4", "mapped":
- ep, err := ipv4.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
- wantTTL = ep.DefaultTTL()
- ep.Close()
- case "v6":
- ep, err := ipv6.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
- wantTTL = ep.DefaultTTL()
- ep.Close()
- default:
- t.Fatal("unknown test variant")
- }
- case "multicast":
- wantTTL = multicastTTL
- multicast = true
- default:
- t.Fatal("unknown test variant")
- }
+func TestMulticastInterfaceOption(t *testing.T) {
+ for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ for _, bindTyp := range []string{"bound", "unbound"} {
+ t.Run(bindTyp, func(t *testing.T) {
+ for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
+ t.Run(optTyp, func(t *testing.T) {
+ h := flow.header4Tuple(outgoing)
+ mcastAddr := h.dstAddr.Addr
+ localIfAddr := h.srcAddr.Addr
+
+ var ifoptSet tcpip.MulticastInterfaceOption
+ switch optTyp {
+ case "use local-addr":
+ ifoptSet.InterfaceAddr = localIfAddr
+ case "use NICID":
+ ifoptSet.NIC = 1
+ case "use local-addr and NIC":
+ ifoptSet.InterfaceAddr = localIfAddr
+ ifoptSet.NIC = 1
+ default:
+ t.Fatal("unknown test variant")
+ }
- var networkProtocolNumber tcpip.NetworkProtocolNumber
- switch variant {
- case "v4", "mapped":
- networkProtocolNumber = ipv4.ProtocolNumber
- case "v6":
- networkProtocolNumber = ipv6.ProtocolNumber
- default:
- t.Fatal("unknown test variant")
- }
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(flow.sockProto())
+
+ if bindTyp == "bound" {
+ // Bind the socket by connecting to the multicast address.
+ // This may have an influence on how the multicast interface
+ // is set.
+ addr := tcpip.FullAddress{
+ Addr: flow.mapAddrIfApplicable(mcastAddr),
+ Port: stackPort,
+ }
+ if err := c.ep.Connect(addr); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+ }
- b := c.getPacket(networkProtocolNumber, multicast)
- checkerFn(c.t, b,
- checker.TTL(wantTTL),
- checker.UDP(
- checker.DstPort(port),
- ),
- )
- })
- }
- })
-}
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
-func TestMulticastInterfaceOption(t *testing.T) {
- setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) {
- for _, bindTyp := range []string{"bound", "unbound"} {
- t.Run(bindTyp, func(t *testing.T) {
- for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
- t.Run(optTyp, func(t *testing.T) {
- var mcastAddr, localIfAddr tcpip.Address
- switch variant {
- case "v4":
- mcastAddr = multicastAddr
- localIfAddr = stackAddr
- case "mapped":
- mcastAddr = multicastV4MappedAddr
- localIfAddr = stackAddr
- case "v6":
- mcastAddr = multicastV6Addr
- localIfAddr = stackV6Addr
- default:
- t.Fatal("unknown test variant")
- }
-
- var ifoptSet tcpip.MulticastInterfaceOption
- switch optTyp {
- case "use local-addr":
- ifoptSet.InterfaceAddr = localIfAddr
- case "use NICID":
- ifoptSet.NIC = 1
- case "use local-addr and NIC":
- ifoptSet.InterfaceAddr = localIfAddr
- ifoptSet.NIC = 1
- default:
- t.Fatal("unknown test variant")
- }
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if bindTyp == "bound" {
- // Bind the socket by connecting to the multicast address.
- // This may have an influence on how the multicast interface
- // is set.
- addr := tcpip.FullAddress{
- Addr: mcastAddr,
- Port: multicastPort,
+ // Verify multicast interface addr and NIC were set correctly.
+ // Note that NIC must be 1 since this is our outgoing interface.
+ ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
+ var ifoptGot tcpip.MulticastInterfaceOption
+ if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
+ c.t.Fatalf("GetSockOpt failed: %v", err)
}
- if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ if ifoptGot != ifoptWant {
+ c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
}
- }
-
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
-
- // Verify multicast interface addr and NIC were set correctly.
- // Note that NIC must be 1 since this is our outgoing interface.
- ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
- var ifoptGot tcpip.MulticastInterfaceOption
- if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %v", err)
- }
- if ifoptGot != ifoptWant {
- c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
- }
- })
- }
- })
- }
- })
+ })
+ }
+ })
+ }
+ })
+ }
}