diff options
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 1094 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ipv4_udp_unbound.cc | 254 |
4 files changed, 897 insertions, 474 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 935ac622e..ac5905772 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -249,6 +249,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // specified address is a multicast address. func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.id.LocalAddress + if isBroadcastOrMulticast(localAddr) { + // A packet can only originate from a unicast address (i.e., an interface). + localAddr = "" + } + if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { if nicid == 0 { nicid = e.multicastNICID @@ -448,7 +453,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } nicID := v.NIC - if v.InterfaceAddr == header.IPv4Any { + + // The interface address is considered not-set if it is empty or contains + // all-zeros. The former represent the zero-value in golang, the latter the + // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall. + allZeros := header.IPv4Any + if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros { if nicID == 0 { r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) if err == nil { @@ -914,8 +924,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } nicid := addr.NIC - if len(addr.Addr) != 0 { - // A local address was specified, verify that it's valid. + if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { + // A local unicast address was specified, verify that it's valid. nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicid == 0 { return tcpip.ErrBadLocalAddress @@ -1064,3 +1074,7 @@ func (e *endpoint) State() uint32 { // TODO(b/112063468): Translate internal state to values returned by Linux. return 0 } + +func isBroadcastOrMulticast(a tcpip.Address) bool { + return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 4a3c30115..5cbb56120 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -97,7 +97,8 @@ func (e *endpoint) Resume(s *stack.Stack) { if err != nil { panic(err) } - } else if len(e.id.LocalAddress) != 0 { // stateBound + } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound + // A local unicast address is specified, verify that it's valid. if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index e6a3a0c0c..9da6edce2 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,6 +16,7 @@ package udp_test import ( "bytes" + "fmt" "math" "math/rand" "testing" @@ -34,13 +35,19 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// Addresses and ports used for testing. It is recommended that tests stick to +// using these addresses as it allows using the testFlow helper. +// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*' +// represents the remote endpoint. const ( + v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr - testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr - multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + stackV4MappedAddr = v4MappedAddrPrefix + stackAddr + testV4MappedAddr = v4MappedAddrPrefix + testAddr + multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr + broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr + v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00" stackAddr = "\x0a\x00\x00\x01" stackPort = 1234 @@ -48,7 +55,7 @@ const ( testPort = 4096 multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - multicastPort = 1234 + broadcastAddr = header.IPv4Broadcast // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -56,6 +63,205 @@ const ( defaultMTU = 65536 ) +// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in +// a packet header. These values are used to populate a header or verify one. +// Note that because they are used in packet headers, the addresses are never in +// a V4-mapped format. +type header4Tuple struct { + srcAddr tcpip.FullAddress + dstAddr tcpip.FullAddress +} + +// testFlow implements a helper type used for sending and receiving test +// packets. A given test flow value defines 1) the socket endpoint used for the +// test and 2) the type of packet send or received on the endpoint. E.g., a +// multicastV6Only flow is a V6 multicast packet passing through a V6-only +// endpoint. The type provides helper methods to characterize the flow (e.g., +// isV4) as well as return a proper header4Tuple for it. +type testFlow int + +const ( + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket +) + +func (flow testFlow) String() string { + switch flow { + case unicastV4: + return "unicastV4" + case unicastV6: + return "unicastV6" + case unicastV6Only: + return "unicastV6Only" + case unicastV4in6: + return "unicastV4in6" + case multicastV4: + return "multicastV4" + case multicastV6: + return "multicastV6" + case multicastV6Only: + return "multicastV6Only" + case multicastV4in6: + return "multicastV4in6" + case broadcast: + return "broadcast" + case broadcastIn6: + return "broadcastIn6" + default: + return "unknown" + } +} + +// packetDirection explains if a flow is incoming (read) or outgoing (write). +type packetDirection int + +const ( + incoming packetDirection = iota + outgoing +) + +// header4Tuple returns the header4Tuple for the given flow and direction. Note +// that the tuple contains no mapped addresses as those only exist at the socket +// level but not at the packet header level. +func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { + var h header4Tuple + if flow.isV4() { + if d == outgoing { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, + dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, + } + } else { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, + dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, + } + } + if flow.isMulticast() { + h.dstAddr.Addr = multicastAddr + } else if flow.isBroadcast() { + h.dstAddr.Addr = broadcastAddr + } + } else { // IPv6 + if d == outgoing { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, + dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + } + } else { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, + } + } + if flow.isMulticast() { + h.dstAddr.Addr = multicastV6Addr + } + } + return h +} + +func (flow testFlow) getMcastAddr() tcpip.Address { + if flow.isV4() { + return multicastAddr + } + return multicastV6Addr +} + +// mapAddrIfApplicable converts the given V4 address into its V4-mapped version +// if it is applicable to the flow. +func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { + if flow.isMapped() { + return v4MappedAddrPrefix + v4Addr + } + return v4Addr +} + +// netProto returns the protocol number used for the network packet. +func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { + if flow.isV4() { + return ipv4.ProtocolNumber + } + return ipv6.ProtocolNumber +} + +// sockProto returns the protocol number used when creating the socket +// endpoint for this flow. +func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { + switch flow { + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + return ipv6.ProtocolNumber + case unicastV4, multicastV4, broadcast: + return ipv4.ProtocolNumber + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { + if flow.isV4() { + return checker.IPv4 + } + return checker.IPv6 +} + +func (flow testFlow) isV6() bool { return !flow.isV4() } +func (flow testFlow) isV4() bool { + return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() +} + +func (flow testFlow) isV6Only() bool { + switch flow { + case unicastV6Only, multicastV6Only: + return true + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isMulticast() bool { + switch flow { + case multicastV4, multicastV4in6, multicastV6, multicastV6Only: + return true + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isBroadcast() bool { + switch flow { + case broadcast, broadcastIn6: + return true + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isMapped() bool { + switch flow { + case unicastV4in6, multicastV4in6, broadcastIn6: + return true + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -65,12 +271,9 @@ type testContext struct { wq waiter.Queue } -type headers struct { - srcPort uint16 - dstPort uint16 -} - func newDualTestContext(t *testing.T, mtu uint32) *testContext { + t.Helper() + s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) id, linkEP := channel.New(256, mtu, "") @@ -113,51 +316,54 @@ func (c *testContext) cleanup() { } } -func (c *testContext) createV6Endpoint(v6only bool) { +func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { + c.t.Helper() + var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) + c.t.Fatal("NewEndpoint failed: ", err) } +} - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.ep.SetSockOpt(v); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) +func (c *testContext) createEndpointForFlow(flow testFlow) { + c.t.Helper() + + c.createEndpoint(flow.sockProto()) + if flow.isV6Only() { + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + } else if flow.isBroadcast() { + if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { + c.t.Fatal("SetSockOpt failed:", err) + } } } -func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte { +// getPacketAndVerify reads a packet from the link endpoint and verifies the +// header against expected values from the given test flow. In addition, it +// calls any extra checker functions provided. +func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { + c.t.Helper() + select { case p := <-c.linkEP.C: - if p.Proto != protocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber) + if p.Proto != flow.netProto() { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } b := make([]byte, len(p.Header)+len(p.Payload)) copy(b, p.Header) copy(b[len(p.Header):], p.Payload) - var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker) - var srcAddr, dstAddr tcpip.Address - switch protocolNumber { - case ipv4.ProtocolNumber: - checkerFn = checker.IPv4 - srcAddr, dstAddr = stackAddr, testAddr - if multicast { - dstAddr = multicastAddr - } - case ipv6.ProtocolNumber: - checkerFn = checker.IPv6 - srcAddr, dstAddr = stackV6Addr, testV6Addr - if multicast { - dstAddr = multicastV6Addr - } - default: - c.t.Fatalf("unknown protocol %d", protocolNumber) - } - checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr)) + h := flow.header4Tuple(outgoing) + checkers := append( + checkers, + checker.SrcAddr(h.srcAddr.Addr), + checker.DstAddr(h.dstAddr.Addr), + checker.UDP(checker.DstPort(h.dstAddr.Port)), + ) + flow.checkerFn()(c.t, b, checkers...) return b case <-time.After(2 * time.Second): @@ -167,7 +373,22 @@ func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, mult return nil } -func (c *testContext) sendV6Packet(payload []byte, h *headers) { +// injectPacket creates a packet of the given flow and with the given payload, +// and injects it into the link endpoint. +func (c *testContext) injectPacket(flow testFlow, payload []byte) { + c.t.Helper() + + h := flow.header4Tuple(incoming) + if flow.isV4() { + c.injectV4Packet(payload, &h) + } else { + c.injectV6Packet(payload, &h) + } +} + +// injectV6Packet creates a V6 test packet with the given payload and header +// values, and injects it into the link endpoint. +func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -178,20 +399,20 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) { PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -201,7 +422,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) { c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) } -func (c *testContext) sendPacket(payload []byte, h *headers) { +// injectV6Packet creates a V4 test packet with the given payload and header +// values, and injects it into the link endpoint. +func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -213,21 +436,21 @@ func (c *testContext) sendPacket(payload []byte, h *headers) { TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), - SrcAddr: testAddr, - DstAddr: stackAddr, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) ip.SetChecksum(^ip.CalculateChecksum()) // Initialize the UDP header. u := header.UDP(buf[header.IPv4MinimumSize:]) u.Encode(&header.UDPFields{ - SrcPort: h.srcPort, - DstPort: h.dstPort, + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testAddr, stackAddr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -249,7 +472,7 @@ func TestBindPortReuse(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) var eps [5]tcpip.Endpoint reusePortOpt := tcpip.ReusePortOption(1) @@ -292,9 +515,9 @@ func TestBindPortReuse(t *testing.T) { // Send a packet. port := uint16(i % nports) payload := newPayload() - c.sendV6Packet(payload, &headers{ - srcPort: testPort + port, - dstPort: stackPort, + c.injectV6Packet(payload, &header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port}, + dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, }) var addr tcpip.FullAddress @@ -329,13 +552,14 @@ func TestBindPortReuse(t *testing.T) { } } -func testV4Read(c *testContext) { - // Send a packet. +// testRead sends a packet of the given test flow into the stack by injecting it +// into the link endpoint. It then reads it from the UDP endpoint and verifies +// its correctness. +func testRead(c *testContext, flow testFlow) { + c.t.Helper() + payload := newPayload() - c.sendPacket(payload, &headers{ - srcPort: testPort, - dstPort: stackPort, - }) + c.injectPacket(flow, payload) // Try to receive the data. we, ch := waiter.NewChannelEntry(nil) @@ -359,8 +583,9 @@ func testV4Read(c *testContext) { } // Check the peer address. - if addr.Addr != testAddr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) + h := flow.header4Tuple(incoming) + if addr.Addr != h.srcAddr.Addr { + c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr) } // Check the payload. @@ -373,7 +598,7 @@ func TestBindEphemeralPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { t.Fatalf("ep.Bind(...) failed: %v", err) @@ -384,7 +609,7 @@ func TestBindReservedPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) @@ -443,7 +668,7 @@ func TestV4ReadOnV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -451,29 +676,29 @@ func TestV4ReadOnV6(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to v4 mapped wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}); err != nil { + if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV4ReadOnBoundToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV4in6) // Bind to local address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { @@ -481,69 +706,29 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4in6) } func TestV6ReadOnV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpointForFlow(unicastV6) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - // Send a packet. - payload := newPayload() - c.sendV6Packet(payload, &headers{ - srcPort: testPort, - dstPort: stackPort, - }) - - // Try to receive the data. - we, ch := waiter.NewChannelEntry(nil) - c.wq.EventRegister(&we, waiter.EventIn) - defer c.wq.EventUnregister(&we) - - var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) - if err == tcpip.ErrWouldBlock { - // Wait for data to become available. - select { - case <-ch: - v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") - } - } - - // Check the peer address. - if addr.Addr != testV6Addr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) - } - - // Check the payload. - if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) - } + // Test acceptance. + testRead(c, unicastV6) } func TestV4ReadOnV4(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - // Create v4 UDP endpoint. - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } + c.createEndpointForFlow(unicastV4) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -551,45 +736,108 @@ func TestV4ReadOnV4(t *testing.T) { } // Test acceptance. - testV4Read(c) + testRead(c, unicastV4) } -func testV4Write(c *testContext) uint16 { - // Write to V4 mapped address. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != nil { - c.t.Fatalf("Write failed: %v", err) +// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast +// address and receive data sent to that address. +func TestReadOnBoundToMulticast(t *testing.T) { + // FIXME(b/128189410): multicastV4in6 currently doesn't work as + // AddMembershipOption doesn't handle V4in6 addresses. + for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to multicast address. + mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr()) + if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + // Join multicast group. + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatal("SetSockOpt failed:", err) + } + + testRead(c, flow) + }) } - if n != int64(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) +} + +// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast +// address and receive broadcast data on it. +func TestV4ReadOnBoundToBroadcast(t *testing.T) { + for _, flow := range []testFlow{broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to broadcast address. + bcastAddr := flow.mapAddrIfApplicable(broadcastAddr) + if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + // Test acceptance. + testRead(c, flow) + }) } +} - // Check that we received the packet. - b := c.getPacket(ipv4.ProtocolNumber, false) - udp := header.UDP(header.IPv4(b).Payload()) - checker.IPv4(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) +// testFailingWrite sends a packet of the given test flow into the UDP endpoint +// and verifies it fails with the provided error code. +func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { + c.t.Helper() - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + h := flow.header4Tuple(outgoing) + writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) + + payload := buffer.View(newPayload()) + _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, + }) + if gotErr != wantErr { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) } +} - return udp.SourcePort() +// testWrite sends a packet of the given test flow from the UDP endpoint to the +// flow's destination address:port. It then receives it from the link endpoint +// and verifies its correctness including any additional checker functions +// provided. +func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + return testWriteInternal(c, flow, true, checkers...) } -func testV6Write(c *testContext) uint16 { - // Write to v6 address. +// testWriteWithoutDestination sends a packet of the given test flow from the +// UDP endpoint without giving a destination address:port. It then receives it +// from the link endpoint and verifies its correctness including any additional +// checker functions provided. +func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + return testWriteInternal(c, flow, false, checkers...) +} + +func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + + writeOpts := tcpip.WriteOptions{} + if setDest { + h := flow.header4Tuple(outgoing) + writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) + writeOpts = tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, + } + } payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { c.t.Fatalf("Write failed: %v", err) } @@ -597,16 +845,14 @@ func testV6Write(c *testContext) uint16 { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } - // Check that we received the packet. - b := c.getPacket(ipv6.ProtocolNumber, false) - udp := header.UDP(header.IPv6(b).Payload()) - checker.IPv6(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. + // Received the packet and check the payload. + b := c.getPacketAndVerify(flow, checkers...) + var udp header.UDP + if flow.isV4() { + udp = header.UDP(header.IPv4(b).Payload()) + } else { + udp = header.UDP(header.IPv6(b).Payload()) + } if !bytes.Equal(payload, udp.Payload()) { c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) } @@ -615,8 +861,10 @@ func testV6Write(c *testContext) uint16 { } func testDualWrite(c *testContext) uint16 { - v4Port := testV4Write(c) - v6Port := testV6Write(c) + c.t.Helper() + + v4Port := testWrite(c, unicastV4in6) + v6Port := testWrite(c, unicastV6) if v4Port != v6Port { c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) } @@ -628,7 +876,7 @@ func TestDualWriteUnbound(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) testDualWrite(c) } @@ -637,7 +885,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { @@ -654,69 +902,51 @@ func TestDualWriteConnectedToV6(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV6Write(c) + testWrite(c, unicastV6) // Write to V4 mapped address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != tcpip.ErrNetworkUnreachable { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable) - } + testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) } func TestDualWriteConnectedToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV4Write(c) + testWrite(c, unicastV4in6) // Write to v6 address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != tcpip.ErrInvalidEndpointState { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) - } + testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) } func TestV4WriteOnV6Only(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(true) + c.createEndpointForFlow(unicastV6Only) // Write to V4 mapped address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, - }) - if err != tcpip.ErrNoRoute { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute) - } + testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute) } func TestV6WriteOnBoundToV4Mapped(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Bind to v4 mapped address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { @@ -724,84 +954,154 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { } // Write to v6 address. - payload := buffer.View(newPayload()) - _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ - To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, - }) - if err != tcpip.ErrInvalidEndpointState { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) - } + testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) } func TestV6WriteOnConnected(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) } - // Write without destination. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != int64(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) - } - - // Check that we received the packet. - b := c.getPacket(ipv6.ProtocolNumber, false) - udp := header.UDP(header.IPv6(b).Payload()) - checker.IPv6(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) - - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) - } + testWriteWithoutDestination(c, unicastV6) } func TestV4WriteOnConnected(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { c.t.Fatalf("Connect failed: %v", err) } - // Write without destination. - payload := buffer.View(newPayload()) - n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) + testWriteWithoutDestination(c, unicastV4) +} + +// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket +// that is bound to a V4 multicast address. +func TestWriteOnBoundToV4Multicast(t *testing.T) { + for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + testWrite(c, flow) + }) } - if n != int64(len(payload)) { - c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) +} + +// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a +// socket that is bound to a V4-mapped multicast address. +func TestWriteOnBoundToV4MappedMulticast(t *testing.T) { + for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4Mapped mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) } +} - // Check that we received the packet. - b := c.getPacket(ipv4.ProtocolNumber, false) - udp := header.UDP(header.IPv4(b).Payload()) - checker.IPv4(c.t, b, - checker.UDP( - checker.DstPort(testPort), - ), - ) +// TestWriteOnBoundToV6Multicast checks that we can send packets out of a +// socket that is bound to a V6 multicast address. +func TestWriteOnBoundToV6Multicast(t *testing.T) { + for _, flow := range []testFlow{unicastV6, multicastV6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - // Check the payload. - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + c.createEndpointForFlow(flow) + + // Bind to V6 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToV6Multicast checks that we can send packets out of a +// V6-only socket that is bound to a V6 multicast address. +func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) { + for _, flow := range []testFlow{unicastV6Only, multicastV6Only} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V6 mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToBroadcast checks that we can send packets out of a +// socket that is bound to the broadcast address. +func TestWriteOnBoundToBroadcast(t *testing.T) { + for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4 broadcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + + testWrite(c, flow) + }) + } +} + +// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a +// socket that is bound to the V4-mapped broadcast address. +func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) { + for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} { + t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to V4Mapped mcast address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testWrite(c, flow) + }) } } @@ -810,18 +1110,14 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { defer c.cleanup() // Create IPv4 UDP endpoint - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } + c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } - testV4Read(c) + testRead(c, unicastV4) var want uint64 = 1 if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { @@ -833,7 +1129,7 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() - c.createV6Endpoint(false) + c.createEndpoint(ipv6.ProtocolNumber) testDualWrite(c) @@ -843,244 +1139,102 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { } } -func setSockOptVariants(t *testing.T, optFunc func(*testing.T, string, tcpip.NetworkProtocolNumber, string)) { - for _, name := range []string{"v4", "v6", "dual"} { - t.Run(name, func(t *testing.T) { - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch name { - case "v4": - networkProtocolNumber = ipv4.ProtocolNumber - case "v6", "dual": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatal("unknown test variant") - } +func TestTTL(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - var variants []string - switch name { - case "v4": - variants = []string{"v4"} - case "v6": - variants = []string{"v6"} - case "dual": - variants = []string{"v6", "mapped"} - } + c.createEndpointForFlow(flow) - for _, variant := range variants { - t.Run(variant, func(t *testing.T) { - optFunc(t, name, networkProtocolNumber, variant) - }) + const multicastTTL = 42 + if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) } - }) - } -} -func TestTTL(t *testing.T) { - payload := tcpip.SlicePayload(buffer.View(newPayload())) - - setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) { - for _, typ := range []string{"unicast", "multicast"} { - t.Run(typ, func(t *testing.T) { - var addr tcpip.Address - var port uint16 - switch typ { - case "unicast": - port = testPort - switch variant { - case "v4": - addr = testAddr - case "mapped": - addr = testV4MappedAddr - case "v6": - addr = testV6Addr - default: - t.Fatal("unknown test variant") - } - case "multicast": - port = multicastPort - switch variant { - case "v4": - addr = multicastAddr - case "mapped": - addr = multicastV4MappedAddr - case "v6": - addr = multicastV6Addr - default: - t.Fatal("unknown test variant") - } - default: - t.Fatal("unknown test variant") + var wantTTL uint8 + if flow.isMulticast() { + wantTTL = multicastTTL + } else { + var p stack.NetworkProtocol + if flow.isV4() { + p = ipv4.NewProtocol() + } else { + p = ipv6.NewProtocol() } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) + ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - switch name { - case "v4": - case "v6": - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - case "dual": - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - default: - t.Fatal("unknown test variant") - } - - const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + t.Fatal(err) } + wantTTL = ep.DefaultTTL() + ep.Close() + } - n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != int64(len(payload)) { - c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload)) - } + testWrite(c, flow, checker.TTL(wantTTL)) + }) + } +} - checkerFn := checker.IPv4 - switch variant { - case "v4", "mapped": - case "v6": - checkerFn = checker.IPv6 - default: - t.Fatal("unknown test variant") - } - var wantTTL uint8 - var multicast bool - switch typ { - case "unicast": - multicast = false - switch variant { - case "v4", "mapped": - ep, err := ipv4.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - case "v6": - ep, err := ipv6.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - default: - t.Fatal("unknown test variant") - } - case "multicast": - wantTTL = multicastTTL - multicast = true - default: - t.Fatal("unknown test variant") - } +func TestMulticastInterfaceOption(t *testing.T) { + for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + for _, bindTyp := range []string{"bound", "unbound"} { + t.Run(bindTyp, func(t *testing.T) { + for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { + t.Run(optTyp, func(t *testing.T) { + h := flow.header4Tuple(outgoing) + mcastAddr := h.dstAddr.Addr + localIfAddr := h.srcAddr.Addr + + var ifoptSet tcpip.MulticastInterfaceOption + switch optTyp { + case "use local-addr": + ifoptSet.InterfaceAddr = localIfAddr + case "use NICID": + ifoptSet.NIC = 1 + case "use local-addr and NIC": + ifoptSet.InterfaceAddr = localIfAddr + ifoptSet.NIC = 1 + default: + t.Fatal("unknown test variant") + } - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch variant { - case "v4", "mapped": - networkProtocolNumber = ipv4.ProtocolNumber - case "v6": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatal("unknown test variant") - } + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(flow.sockProto()) + + if bindTyp == "bound" { + // Bind the socket by connecting to the multicast address. + // This may have an influence on how the multicast interface + // is set. + addr := tcpip.FullAddress{ + Addr: flow.mapAddrIfApplicable(mcastAddr), + Port: stackPort, + } + if err := c.ep.Connect(addr); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + } - b := c.getPacket(networkProtocolNumber, multicast) - checkerFn(c.t, b, - checker.TTL(wantTTL), - checker.UDP( - checker.DstPort(port), - ), - ) - }) - } - }) -} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } -func TestMulticastInterfaceOption(t *testing.T) { - setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) { - for _, bindTyp := range []string{"bound", "unbound"} { - t.Run(bindTyp, func(t *testing.T) { - for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { - t.Run(optTyp, func(t *testing.T) { - var mcastAddr, localIfAddr tcpip.Address - switch variant { - case "v4": - mcastAddr = multicastAddr - localIfAddr = stackAddr - case "mapped": - mcastAddr = multicastV4MappedAddr - localIfAddr = stackAddr - case "v6": - mcastAddr = multicastV6Addr - localIfAddr = stackV6Addr - default: - t.Fatal("unknown test variant") - } - - var ifoptSet tcpip.MulticastInterfaceOption - switch optTyp { - case "use local-addr": - ifoptSet.InterfaceAddr = localIfAddr - case "use NICID": - ifoptSet.NIC = 1 - case "use local-addr and NIC": - ifoptSet.InterfaceAddr = localIfAddr - ifoptSet.NIC = 1 - default: - t.Fatal("unknown test variant") - } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if bindTyp == "bound" { - // Bind the socket by connecting to the multicast address. - // This may have an influence on how the multicast interface - // is set. - addr := tcpip.FullAddress{ - Addr: mcastAddr, - Port: multicastPort, + // Verify multicast interface addr and NIC were set correctly. + // Note that NIC must be 1 since this is our outgoing interface. + ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} + var ifoptGot tcpip.MulticastInterfaceOption + if err := c.ep.GetSockOpt(&ifoptGot); err != nil { + c.t.Fatalf("GetSockOpt failed: %v", err) } - if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %v", err) + if ifoptGot != ifoptWant { + c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) } - } - - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - - // Verify multicast interface addr and NIC were set correctly. - // Note that NIC must be 1 since this is our outgoing interface. - ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} - var ifoptGot tcpip.MulticastInterfaceOption - if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %v", err) - } - if ifoptGot != ifoptWant { - c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) - } - }) - } - }) - } - }) + }) + } + }) + } + }) + } } diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index d9aa7ff3f..67d29af0a 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -30,6 +30,7 @@ namespace gvisor { namespace testing { constexpr char kMulticastAddress[] = "224.0.2.1"; +constexpr char kBroadcastAddress[] = "255.255.255.255"; TestAddress V4Multicast() { TestAddress t("V4Multicast"); @@ -40,6 +41,15 @@ TestAddress V4Multicast() { return t; } +TestAddress V4Broadcast() { + TestAddress t("V4Broadcast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = + inet_addr(kBroadcastAddress); + return t; +} + // Check that packets are not received without a group membership. Default send // interface configured by bind. TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNoGroup) { @@ -1426,5 +1436,249 @@ TEST_P(IPv4UDPUnboundSocketPairTest, } } +// Check that a receiving socket can bind to the multicast address before +// joining the group and receive data once the group has been joined. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the multicast address. + auto receiver_addr = V4Multicast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + // Update receiver_addr with the correct port number. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Register to receive multicast packets. + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); + ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + + // Send a multicast packet on the first socket out the loopback interface. + ip_mreq iface = {}; + iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, + &iface, sizeof(iface)), + SyscallSucceeds()); + auto sendto_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a receiving socket can bind to the multicast address and won't +// receive multicast data if it hasn't joined the group. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenNoJoinThenNoReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the multicast address. + auto receiver_addr = V4Multicast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + // Update receiver_addr with the correct port number. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Send a multicast packet on the first socket out the loopback interface. + ip_mreq iface = {}; + iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, + &iface, sizeof(iface)), + SyscallSucceeds()); + auto sendto_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we don't receive the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// Check that a socket can bind to a multicast address and still send out +// packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the ANY address. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Bind the first socket (sender) to the multicast address. + auto sender_addr = V4Multicast(); + ASSERT_THAT( + bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + // Send a packet on the first socket to the loopback address. + auto sendto_addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a receiving socket can bind to the broadcast address and receive +// broadcast packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the broadcast address. + auto receiver_addr = V4Broadcast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Send a broadcast packet on the first socket out the loopback interface. + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + // Note: Binding to the loopback interface makes the broadcast go out of it. + auto sender_bind_addr = V4Loopback(); + ASSERT_THAT(bind(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_bind_addr.addr), + sender_bind_addr.addr_len), + SyscallSucceeds()); + auto sendto_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a socket can bind to the broadcast address and still send out +// packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the ANY address. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Bind the first socket (sender) to the broadcast address. + auto sender_addr = V4Broadcast(); + ASSERT_THAT( + bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + // Send a packet on the first socket to the loopback address. + auto sendto_addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + } // namespace testing } // namespace gvisor |