diff options
author | Chris Kuiper <ckuiper@google.com> | 2019-08-21 22:53:07 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-08-21 22:54:25 -0700 |
commit | 8d9276ed564ffef5d12426e839aeef7de5164d7d (patch) | |
tree | 460163aed957d3fb241016a69c975857d0feefa4 | |
parent | 5fd63d1c7fc8ea8d0aff30abc6403aa4491b6f81 (diff) |
Support binding to multicast and broadcast addresses
This fixes the issue of not being able to bind to either a multicast or
broadcast address as well as to send and receive data from it. The way to solve
this is to treat these addresses similar to the ANY address and register their
transport endpoint ID with the global stack's demuxer rather than the NIC's.
That way there is no need to require an endpoint with that multicast or
broadcast address. The stack's demuxer is in fact the only correct one to use,
because neither broadcast- nor multicast-bound sockets care which NIC a
packet was received on (for multicast a join is still needed to receive packets
on a NIC).
I also took the liberty of refactoring udp_test.go to consolidate a lot of
duplicate code and make it easier to create repetitive tests that test the same
feature for a variety of packet and socket types. For this purpose I created a
"flowType" that represents two things: 1) the type of packet being sent or
received and 2) the type of socket used for the test. E.g., a "multicastV4in6"
flow represents a V4-mapped multicast packet run through a V6-dual socket.
This allows writing significantly simpler tests. A nice example is testTTL().
PiperOrigin-RevId: 264766909
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 1094 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ipv4_udp_unbound.cc | 254 |
4 files changed, 897 insertions, 474 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 935ac622e..ac5905772 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -249,6 +249,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // specified address is a multicast address. func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.id.LocalAddress + if isBroadcastOrMulticast(localAddr) { + // A packet can only originate from a unicast address (i.e., an interface). + localAddr = "" + } + if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { if nicid == 0 { nicid = e.multicastNICID @@ -448,7 +453,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } nicID := v.NIC - if v.InterfaceAddr == header.IPv4Any { + + // The interface address is considered not-set if it is empty or contains + // all-zeros. The former represent the zero-value in golang, the latter the + // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall. + allZeros := header.IPv4Any + if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros { if nicID == 0 { r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) if err == nil { @@ -914,8 +924,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } nicid := addr.NIC - if len(addr.Addr) != 0 { - // A local address was specified, verify that it's valid. + if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { + // A local unicast address was specified, verify that it's valid. nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicid == 0 { return tcpip.ErrBadLocalAddress @@ -1064,3 +1074,7 @@ func (e *endpoint) State() uint32 { // TODO(b/112063468): Translate internal state to values returned by Linux. return 0 } + +func isBroadcastOrMulticast(a tcpip.Address) bool { + return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 4a3c30115..5cbb56120 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -97,7 +97,8 @@ func (e *endpoint) Resume(s *stack.Stack) { if err != nil { panic(err) } - } else if len(e.id.LocalAddress) != 0 { // stateBound + } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound + // A local unicast address is specified, verify that it's valid. if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index e6a3a0c0c..9da6edce2 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,6 +16,7 @@ package udp_test import ( "bytes" + "fmt" "math" "math/rand" "testing" @@ -34,13 +35,19 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// Addresses and ports used for testing. It is recommended that tests stick to +// using these addresses as it allows using the testFlow helper. +// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*' +// represents the remote endpoint. const ( + v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr - testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr - multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + stackV4MappedAddr = v4MappedAddrPrefix + stackAddr + testV4MappedAddr = v4MappedAddrPrefix + testAddr + multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr + broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr + v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00" stackAddr = "\x0a\x00\x00\x01" stackPort = 1234 @@ -48,7 +55,7 @@ const ( testPort = 4096 multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - multicastPort = 1234 + broadcastAddr = header.IPv4Broadcast // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -56,6 +63,205 @@ const ( defaultMTU = 65536 ) +// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in +// a packet header. These values are used to populate a header or verify one. +// Note that because they are used in packet headers, the addresses are never in +// a V4-mapped format. +type header4Tuple struct { + srcAddr tcpip.FullAddress + dstAddr tcpip.FullAddress +} + +// testFlow implements a helper type used for sending and receiving test +// packets. A given test flow value defines 1) the socket endpoint used for the +// test and 2) the type of packet send or received on the endpoint. E.g., a +// multicastV6Only flow is a V6 multicast packet passing through a V6-only +// endpoint. The type provides helper methods to characterize the flow (e.g., +// isV4) as well as return a proper header4Tuple for it. +type testFlow int + +const ( + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket +) + +func (flow testFlow) String() string { + switch flow { + case unicastV4: + return "unicastV4" + case unicastV6: + return "unicastV6" + case unicastV6Only: + return "unicastV6Only" + case unicastV4in6: + return "unicastV4in6" + case multicastV4: + return "multicastV4" + case multicastV6: + return "multicastV6" + case multicastV6Only: + return "multicastV6Only" + case multicastV4in6: + return "multicastV4in6" + case broadcast: + return "broadcast" + case broadcastIn6: + return "broadcastIn6" + default: + return "unknown" + } +} + +// packetDirection explains if a flow is incoming (read) or outgoing (write). +type packetDirection int + +const ( + incoming packetDirection = iota + outgoing +) + +// header4Tuple returns the header4Tuple for the given flow and direction. Note +// that the tuple contains no mapped addresses as those only exist at the socket +// level but not at the packet header level. +func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { + var h header4Tuple + if flow.isV4() { + if d == outgoing { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, + dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, + } + } else { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, + dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, + } + } + if flow.isMulticast() { + h.dstAddr.Addr = multicastAddr + } else if flow.isBroadcast() { + h.dstAddr.Addr = broadcastAddr + } + } else { // IPv6 + if d == outgoing { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, + dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + } + } else { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, + } + } + if flow.isMulticast() { + h.dstAddr.Addr = multicastV6Addr + } + } + return h +} + +func (flow testFlow) getMcastAddr() tcpip.Address { + if flow.isV4() { + return multicastAddr + } + return multicastV6Addr +} + +// mapAddrIfApplicable converts the given V4 address into its V4-mapped version +// if it is applicable to the flow. +func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { + if flow.isMapped() { + return v4MappedAddrPrefix + v4Addr + } + return v4Addr +} + +// netProto returns the protocol number used for the network packet. +func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { + if flow.isV4() { + return ipv4.ProtocolNumber + } + return ipv6.ProtocolNumber +} + +// sockProto returns the protocol number used when creating the socket +// endpoint for this flow. +func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { + switch flow { + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + return ipv6.ProtocolNumber + case unicastV4, multicastV4, broadcast: + return ipv4.ProtocolNumber + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { + if flow.isV4() { + return checker.IPv4 + } + return checker.IPv6 +} + +func (flow testFlow) isV6() bool { return !flow.isV4() } +func (flow testFlow) isV4() bool { + return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() +} + +func (flow testFlow) isV6Only() bool { + switch flow { + case unicastV6Only, multicastV6Only: + return true + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isMulticast() bool { + switch flow { + case multicastV4, multicastV4in6, multicastV6, multicastV6Only: + return true + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isBroadcast() bool { + switch flow { + case broadcast, broadcastIn6: + return true + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isMapped() bool { + switch flow { + case unicastV4in6, multicastV4in6, broadcastIn6: + return true + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -65,12 +271,9 @@ type testContext struct { wq waiter.Queue } -type headers struct { - srcPort uint16 - dstPort uint16 -} - func newDualTestContext(t *testing.T, mtu uint32) *testContext { + t.Helper() + s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) id, linkEP := channel.New(256, mtu, "") @@ -113,51 +316,54 @@ func (c *testContext) cleanup() { } } -func (c *testContext) createV6Endpoint(v6only bool) { +func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { + c.t.Helper() + var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) + c.t.Fatal("NewEndpoint failed: ", err) } +} - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.ep.SetSockOpt(v); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) +func (c *testContext) createEndpointForFlow(flow testFlow) { + c.t.Helper() + + c.createEndpoint(flow.sockProto()) + if flow.isV6Only() { + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + } else if flow.isBroadcast() { + if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { + c.t.Fatal("SetSockOpt failed:", err) + } } } -func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte { +// getPacketAndVerify reads a packet from the link endpoint and verifies the +// header against expected values from the given test flow. In addition, it +// calls any extra checker functions provided. +func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { + c.t.Helper() + select { case p := <-c.linkEP.C: - if p.Proto != protocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber) + if p.Proto != flow.netProto() { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } b := make([]byte, len(p.Header)+len(p.Payload)) copy(b, p.Header) copy(b[len(p.Header):], p.Payload) - var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker) - var srcAddr, dstAddr tcpip.Address - switch protocolNumber { - case ipv4.ProtocolNumber: - checkerFn = checker.IPv4 - srcAddr, dstAddr = stackAddr, testAddr - if multicast { - dstAddr = multicastAddr - } - case ipv6.ProtocolNumber: - checkerFn = checker.IPv6 - srcAddr, dstAddr = stackV6Addr, testV6Addr - if multicast { - dstAddr = multicastV6Addr - } - default: - c.t.Fatalf("unknown protocol %d", protocolNumber) - } - checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr)) + h := flow.header4Tuple(outgoing) + checkers := append( + checkers, + checker.SrcAddr(h.srcAddr.Addr), + checker.DstAddr(h.dstAddr.Addr), + checker.UDP(checker.DstPort(h.dstAddr.Port)), + ) + flow.checkerFn()(c.t, b, checkers...) return b case <-time.After(2 * time.Second): @@ -167,7 +373,22 @@ func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, mult return nil } -func (c *testContext) sendV6Packet(payload []byte, h *headers) { +// injectPacket creates a packet of the given flow and with the given payload, +// and injects it into the link endpoint. +func (c *testContext) injectPacket(flow testFlow, payload []byte) { + c.t.Helper() + + h := flow.header4Tuple(incoming) + if flow.isV4() { + c.injectV4Packet(payload, &h) + } else { + c.injectV6Packet(payload, &h) + } +} + +// injectV6Packet creates a V6 test packet with the given payload and header +// values, and injects it into the link endpoint. +func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -178,20 +399,20 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) { PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -201,7 +422,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) { c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) } -func (c *testContext) sendPacket(payload []byte, h *headers) { +// injectV6Packet creates a V4 test packet with the given payload and header +// values, and injects it into the link endpoint. +func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -213,21 +436,21 @@ func (c *testContext) sendPacket(payload []byte, h *headers) { TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), - SrcAddr: testAddr, - DstAddr: stackAddr, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) ip.SetChecksum(^ip.CalculateChecksum()) // Initialize the UDP header. u := header.UDP(buf[header.IPv4MinimumSize:]) u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testAddr, stackAddr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -249,7 +472,7 @@ func TestBindPortReuse(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) var eps [5]tcpip.Endpoint reusePortOpt := tcpip.ReusePortOption(1) @@ -292,9 +515,9 @@ func TestBindPortReuse(t *testing.T) { // Send a packet. port := uint16(i % nports) payload := newPayload() - c.sendV6Packet(payload, &headers{ - srcPort: testPort + port, - dstPort: stackPort, + c.injectV6Packet(payload, &header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port}, + dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, }) var addr tcpip.FullAddress @@ -329,13 +552,14 @@ func TestBindPortReuse(t *testing.T) { } } -func testV4Read(c *testContext) { - // Send a packet. +// testRead sends a packet of the given test flow into the stack by injecting it +// into the link endpoint. It then reads it from the UDP endpoint and verifies +// its correctness. +func testRead(c *testContext, flow testFlow) { + c.t.Helper() + payload := newPayload() - c.sendPacket(payload, &headers{ - srcPort: testPort, - dstPort: stackPort, - }) + c.injectPacket(flow, payload) // Try to receive the data. we, ch := waiter.NewChannelEntry(nil) @@ -359,8 +583,9 @@ func testV4Read(c *testContext) { } // Check the peer address. - if addr.Addr != testAddr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) + h := flow.header4Tuple(incoming) + if addr.Addr != h.srcAddr.Addr { + c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr) } // Check the payload. @@ -373,7 +598,7 @@ func TestBindEphemeralPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { t.Fatalf("ep.Bind(...) failed: %v", err) @@ -384,7 +609,7 @@ func TestBindReservedPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) @@ -443,7 +668,7 @@ func TestV4ReadOnV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -451,29 +676,29 @@ func TestV4ReadOnV6(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to v4 mapped wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}); err != nil { + if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV4ReadOnBoundToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to local address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { @@ -481,69 +706,29 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV6ReadOnV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV6) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - // Send a packet. - payload := newPayload() - c.sendV6Packet(payload, &headers{ - srcPort: testPort, - dstPort: stackPort, - }) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.EventIn) - defer c.wq.EventUnregister(&we) - - var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) - if err == tcpip.ErrWouldBlock { - // Wait for data to become available. - select { - case <-ch: - v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") - } - } - - // Check the peer address. - if addr.Addr != testV6Addr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) - } - - // Check the payload. - if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) - } + // Test acceptance. + testRead(c, unicastV6) } func TestV4ReadOnV4(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - // Create v4 UDP endpoint. - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } + c.createEndpointForFlow(unicastV4) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -551,45 +736,108 @@ func TestV4ReadOnV4(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4) } -func testV4Write(c *testContext) uint16 { - // Write to V4 mapped address. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != nil { - c.t.Fatalf("Write failed: %v", err) +// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast +// address and receive data sent to that address. +func TestReadOnBoundToMulticast(t *testing.T) { + // FIXME(b/128189410): multicastV4in6 currently doesn't work as + // AddMembershipOption doesn't handle V4in6 addresses. + for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to multicast address. + mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr()) + if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + // Join multicast group. + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatal("SetSockOpt failed:", err) + } + + testRead(c, flow) + }) } - if n != int64(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) +} + +// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast +// address and receive broadcast data on it. +func TestV4ReadOnBoundToBroadcast(t *testing.T) { + for _, flow := range []testFlow{broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to broadcast address. + bcastAddr := flow.mapAddrIfApplicable(broadcastAddr) + if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + // Test acceptance. + testRead(c, flow) + }) } +} - // Check that we received the packet. - b := c.getPacket(ipv4.ProtocolNumber, false) - udp := header.UDP(header.IPv4(b).Payload()) - checker.IPv4(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) +// testFailingWrite sends a packet of the given test flow into the UDP endpoint +// and verifies it fails with the provided error code. +func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { + c.t.Helper() - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + h := flow.header4Tuple(outgoing) + writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) + + payload := buffer.View(newPayload()) + _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, + }) + if gotErr != wantErr { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) } +} - return udp.SourcePort() +// testWrite sends a packet of the given test flow from the UDP endpoint to the +// flow's destination address:port. It then receives it from the link endpoint +// and verifies its correctness including any additional checker functions +// provided. +func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + return testWriteInternal(c, flow, true, checkers...) } -func testV6Write(c *testContext) uint16 { - // Write to v6 address. +// testWriteWithoutDestination sends a packet of the given test flow from the +// UDP endpoint without giving a destination address:port. It then receives it +// from the link endpoint and verifies its correctness including any additional +// checker functions provided. +func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + return testWriteInternal(c, flow, false, checkers...) +} + +func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + + writeOpts := tcpip.WriteOptions{} + if setDest { + h := flow.header4Tuple(outgoing) + writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) + writeOpts = tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, + } + } payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { c.t.Fatalf("Write failed: %v", err) } @@ -597,16 +845,14 @@ func testV6Write(c *testContext) uint16 { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } - // Check that we received the packet. - b := c.getPacket(ipv6.ProtocolNumber, false) - udp := header.UDP(header.IPv6(b).Payload()) - checker.IPv6(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. + // Received the packet and check the payload. + b := c.getPacketAndVerify(flow, checkers...) + var udp header.UDP + if flow.isV4() { + udp = header.UDP(header.IPv4(b).Payload()) + } else { + udp = header.UDP(header.IPv6(b).Payload()) + } if !bytes.Equal(payload, udp.Payload()) { c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) } @@ -615,8 +861,10 @@ func testV6Write(c *testContext) uint16 { } func testDualWrite(c *testContext) uint16 { - v4Port := testV4Write(c) - v6Port := testV6Write(c) + c.t.Helper() + + v4Port := testWrite(c, unicastV4in6) + v6Port := testWrite(c, unicastV6) if v4Port != v6Port { c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) } @@ -628,7 +876,7 @@ func TestDualWriteUnbound(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) testDualWrite(c) } @@ -637,7 +885,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -654,69 +902,51 @@ func TestDualWriteConnectedToV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV6Write(c) + testWrite(c, unicastV6) // Write to V4 mapped address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != tcpip.ErrNetworkUnreachable { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable) - } + testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) } func TestDualWriteConnectedToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV4Write(c) + testWrite(c, unicastV4in6) // Write to v6 address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != tcpip.ErrInvalidEndpointState { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) - } + testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) } func TestV4WriteOnV6Only(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(true) + c.createEndpointForFlow(unicastV6Only) // Write to V4 mapped address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != tcpip.ErrNoRoute { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute) - } + testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute) } func TestV6WriteOnBoundToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Bind to v4 mapped address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { @@ -724,84 +954,154 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { } // Write to v6 address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != tcpip.ErrInvalidEndpointState { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) - } + testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) } func TestV6WriteOnConnected(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) } - // Write without destination. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != int64(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - - // Check that we received the packet. - b := c.getPacket(ipv6.ProtocolNumber, false) - udp := header.UDP(header.IPv6(b).Payload()) - checker.IPv6(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } + testWriteWithoutDestination(c, unicastV6) } func TestV4WriteOnConnected(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) } - // Write without destination. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) + testWriteWithoutDestination(c, unicastV4) +} + +// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket +// that is bound to a V4 multicast address. +func TestWriteOnBoundToV4Multicast(t *testing.T) { + for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + testWrite(c, flow) + }) } - if n != int64(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) +} + +// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a +// socket that is bound to a V4-mapped multicast address. +func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { + for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4Mapped mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) } +} - // Check that we received the packet. - b := c.getPacket(ipv4.ProtocolNumber, false) - udp := header.UDP(header.IPv4(b).Payload()) - checker.IPv4(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) +// TestWriteOnBoundToV6Multicast checks that we can send packets out of a +// socket that is bound to a V6 multicast address. +func TestWriteOnBoundToV6Multicast(t *testing.T) { + for _, flow := range []testFlow{unicastV6, multicastV6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + c.createEndpointForFlow(flow) + + // Bind to V6 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToV6Multicast checks that we can send packets out of a +// V6-only socket that is bound to a V6 multicast address. +func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { + for _, flow := range []testFlow{unicastV6Only, multicastV6Only} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V6 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToBroadcast checks that we can send packets out of a +// socket that is bound to the broadcast address. +func TestWriteOnBoundToBroadcast(t *testing.T) { + for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4 broadcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a +// socket that is bound to the V4-mapped broadcast address. +func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) { + for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4Mapped mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) } } @@ -810,18 +1110,14 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { defer c.cleanup() // Create IPv4 UDP endpoint - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } + c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV4Read(c) + testRead(c, unicastV4) var want uint64 = 1 if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { @@ -833,7 +1129,7 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) testDualWrite(c) @@ -843,244 +1139,102 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { } } -func setSockOptVariants(t *testing.T, optFunc func(*testing.T, string, tcpip.NetworkProtocolNumber, string)) { - for _, name := range []string{"v4", "v6", "dual"} { - t.Run(name, func(t *testing.T) { - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch name { - case "v4": - networkProtocolNumber = ipv4.ProtocolNumber - case "v6", "dual": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatal("unknown test variant") - } +func TestTTL(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - var variants []string - switch name { - case "v4": - variants = []string{"v4"} - case "v6": - variants = []string{"v6"} - case "dual": - variants = []string{"v6", "mapped"} - } + c.createEndpointForFlow(flow) - for _, variant := range variants { - t.Run(variant, func(t *testing.T) { - optFunc(t, name, networkProtocolNumber, variant) - }) + const multicastTTL = 42 + if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) } - }) - } -} -func TestTTL(t *testing.T) { - payload := tcpip.SlicePayload(buffer.View(newPayload())) - - setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) { - for _, typ := range []string{"unicast", "multicast"} { - t.Run(typ, func(t *testing.T) { - var addr tcpip.Address - var port uint16 - switch typ { - case "unicast": - port = testPort - switch variant { - case "v4": - addr = testAddr - case "mapped": - addr = testV4MappedAddr - case "v6": - addr = testV6Addr - default: - t.Fatal("unknown test variant") - } - case "multicast": - port = multicastPort - switch variant { - case "v4": - addr = multicastAddr - case "mapped": - addr = multicastV4MappedAddr - case "v6": - addr = multicastV6Addr - default: - t.Fatal("unknown test variant") - } - default: - t.Fatal("unknown test variant") + var wantTTL uint8 + if flow.isMulticast() { + wantTTL = multicastTTL + } else { + var p stack.NetworkProtocol + if flow.isV4() { + p = ipv4.NewProtocol() + } else { + p = ipv6.NewProtocol() } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) + ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - switch name { - case "v4": - case "v6": - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - case "dual": - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - default: - t.Fatal("unknown test variant") - } - - const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + t.Fatal(err) } + wantTTL = ep.DefaultTTL() + ep.Close() + } - n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != int64(len(payload)) { - c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload)) - } + testWrite(c, flow, checker.TTL(wantTTL)) + }) + } +} - checkerFn := checker.IPv4 - switch variant { - case "v4", "mapped": - case "v6": - checkerFn = checker.IPv6 - default: - t.Fatal("unknown test variant") - } - var wantTTL uint8 - var multicast bool - switch typ { - case "unicast": - multicast = false - switch variant { - case "v4", "mapped": - ep, err := ipv4.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - case "v6": - ep, err := ipv6.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - default: - t.Fatal("unknown test variant") - } - case "multicast": - wantTTL = multicastTTL - multicast = true - default: - t.Fatal("unknown test variant") - } +func TestMulticastInterfaceOption(t *testing.T) { + for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + for _, bindTyp := range []string{"bound", "unbound"} { + t.Run(bindTyp, func(t *testing.T) { + for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { + t.Run(optTyp, func(t *testing.T) { + h := flow.header4Tuple(outgoing) + mcastAddr := h.dstAddr.Addr + localIfAddr := h.srcAddr.Addr + + var ifoptSet tcpip.MulticastInterfaceOption + switch optTyp { + case "use local-addr": + ifoptSet.InterfaceAddr = localIfAddr + case "use NICID": + ifoptSet.NIC = 1 + case "use local-addr and NIC": + ifoptSet.InterfaceAddr = localIfAddr + ifoptSet.NIC = 1 + default: + t.Fatal("unknown test variant") + } - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch variant { - case "v4", "mapped": - networkProtocolNumber = ipv4.ProtocolNumber - case "v6": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatal("unknown test variant") - } + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(flow.sockProto()) + + if bindTyp == "bound" { + // Bind the socket by connecting to the multicast address. + // This may have an influence on how the multicast interface + // is set. + addr := tcpip.FullAddress{ + Addr: flow.mapAddrIfApplicable(mcastAddr), + Port: stackPort, + } + if err := c.ep.Connect(addr); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + } - b := c.getPacket(networkProtocolNumber, multicast) - checkerFn(c.t, b, - checker.TTL(wantTTL), - checker.UDP( - checker.DstPort(port), - ), - ) - }) - } - }) -} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } -func TestMulticastInterfaceOption(t *testing.T) { - setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) { - for _, bindTyp := range []string{"bound", "unbound"} { - t.Run(bindTyp, func(t *testing.T) { - for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { - t.Run(optTyp, func(t *testing.T) { - var mcastAddr, localIfAddr tcpip.Address - switch variant { - case "v4": - mcastAddr = multicastAddr - localIfAddr = stackAddr - case "mapped": - mcastAddr = multicastV4MappedAddr - localIfAddr = stackAddr - case "v6": - mcastAddr = multicastV6Addr - localIfAddr = stackV6Addr - default: - t.Fatal("unknown test variant") - } - - var ifoptSet tcpip.MulticastInterfaceOption - switch optTyp { - case "use local-addr": - ifoptSet.InterfaceAddr = localIfAddr - case "use NICID": - ifoptSet.NIC = 1 - case "use local-addr and NIC": - ifoptSet.InterfaceAddr = localIfAddr - ifoptSet.NIC = 1 - default: - t.Fatal("unknown test variant") - } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if bindTyp == "bound" { - // Bind the socket by connecting to the multicast address. - // This may have an influence on how the multicast interface - // is set. - addr := tcpip.FullAddress{ - Addr: mcastAddr, - Port: multicastPort, + // Verify multicast interface addr and NIC were set correctly. + // Note that NIC must be 1 since this is our outgoing interface. + ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} + var ifoptGot tcpip.MulticastInterfaceOption + if err := c.ep.GetSockOpt(&ifoptGot); err != nil { + c.t.Fatalf("GetSockOpt failed: %v", err) } - if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %v", err) + if ifoptGot != ifoptWant { + c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) } - } - - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - // Verify multicast interface addr and NIC were set correctly. - // Note that NIC must be 1 since this is our outgoing interface. - ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} - var ifoptGot tcpip.MulticastInterfaceOption - if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %v", err) - } - if ifoptGot != ifoptWant { - c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) - } - }) - } - }) - } - }) + }) + } + }) + } + }) + } } diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index d9aa7ff3f..67d29af0a 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -30,6 +30,7 @@ namespace gvisor { namespace testing { constexpr char kMulticastAddress[] = "224.0.2.1"; +constexpr char kBroadcastAddress[] = "255.255.255.255"; TestAddress V4Multicast() { TestAddress t("V4Multicast"); @@ -40,6 +41,15 @@ TestAddress V4Multicast() { return t; } +TestAddress V4Broadcast() { + TestAddress t("V4Broadcast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = + inet_addr(kBroadcastAddress); + return t; +} + // Check that packets are not received without a group membership. Default send // interface configured by bind. TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNoGroup) { @@ -1426,5 +1436,249 @@ TEST_P(IPv4UDPUnboundSocketPairTest, } } +// Check that a receiving socket can bind to the multicast address before +// joining the group and receive data once the group has been joined. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the multicast address. + auto receiver_addr = V4Multicast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + // Update receiver_addr with the correct port number. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Register to receive multicast packets. + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); + ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + + // Send a multicast packet on the first socket out the loopback interface. + ip_mreq iface = {}; + iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, + &iface, sizeof(iface)), + SyscallSucceeds()); + auto sendto_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a receiving socket can bind to the multicast address and won't +// receive multicast data if it hasn't joined the group. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenNoJoinThenNoReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the multicast address. + auto receiver_addr = V4Multicast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + // Update receiver_addr with the correct port number. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Send a multicast packet on the first socket out the loopback interface. + ip_mreq iface = {}; + iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, + &iface, sizeof(iface)), + SyscallSucceeds()); + auto sendto_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we don't receive the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// Check that a socket can bind to a multicast address and still send out +// packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the ANY address. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Bind the first socket (sender) to the multicast address. + auto sender_addr = V4Multicast(); + ASSERT_THAT( + bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + // Send a packet on the first socket to the loopback address. + auto sendto_addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a receiving socket can bind to the broadcast address and receive +// broadcast packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the broadcast address. + auto receiver_addr = V4Broadcast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Send a broadcast packet on the first socket out the loopback interface. + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + // Note: Binding to the loopback interface makes the broadcast go out of it. + auto sender_bind_addr = V4Loopback(); + ASSERT_THAT(bind(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_bind_addr.addr), + sender_bind_addr.addr_len), + SyscallSucceeds()); + auto sendto_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a socket can bind to the broadcast address and still send out +// packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the ANY address. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Bind the first socket (sender) to the broadcast address. + auto sender_addr = V4Broadcast(); + ASSERT_THAT( + bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + // Send a packet on the first socket to the loopback address. + auto sendto_addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + } // namespace testing } // namespace gvisor |