summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go20
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go3
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go1094
3 files changed, 643 insertions, 474 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 935ac622e..ac5905772 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -249,6 +249,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// specified address is a multicast address.
func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.id.LocalAddress
+ if isBroadcastOrMulticast(localAddr) {
+ // A packet can only originate from a unicast address (i.e., an interface).
+ localAddr = ""
+ }
+
if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
if nicid == 0 {
nicid = e.multicastNICID
@@ -448,7 +453,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
nicID := v.NIC
- if v.InterfaceAddr == header.IPv4Any {
+
+ // The interface address is considered not-set if it is empty or contains
+ // all-zeros. The former represent the zero-value in golang, the latter the
+ // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall.
+ allZeros := header.IPv4Any
+ if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros {
if nicID == 0 {
r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
if err == nil {
@@ -914,8 +924,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
}
nicid := addr.NIC
- if len(addr.Addr) != 0 {
- // A local address was specified, verify that it's valid.
+ if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
+ // A local unicast address was specified, verify that it's valid.
nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nicid == 0 {
return tcpip.ErrBadLocalAddress
@@ -1064,3 +1074,7 @@ func (e *endpoint) State() uint32 {
// TODO(b/112063468): Translate internal state to values returned by Linux.
return 0
}
+
+func isBroadcastOrMulticast(a tcpip.Address) bool {
+ return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 4a3c30115..5cbb56120 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -97,7 +97,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
if err != nil {
panic(err)
}
- } else if len(e.id.LocalAddress) != 0 { // stateBound
+ } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound
+ // A local unicast address is specified, verify that it's valid.
if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index e6a3a0c0c..9da6edce2 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,6 +16,7 @@ package udp_test
import (
"bytes"
+ "fmt"
"math"
"math/rand"
"testing"
@@ -34,13 +35,19 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// Addresses and ports used for testing. It is recommended that tests stick to
+// using these addresses as it allows using the testFlow helper.
+// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*'
+// represents the remote endpoint.
const (
+ v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
- testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
- multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr
- V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+ stackV4MappedAddr = v4MappedAddrPrefix + stackAddr
+ testV4MappedAddr = v4MappedAddrPrefix + testAddr
+ multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
+ broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr
+ v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00"
stackAddr = "\x0a\x00\x00\x01"
stackPort = 1234
@@ -48,7 +55,7 @@ const (
testPort = 4096
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- multicastPort = 1234
+ broadcastAddr = header.IPv4Broadcast
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -56,6 +63,205 @@ const (
defaultMTU = 65536
)
+// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in
+// a packet header. These values are used to populate a header or verify one.
+// Note that because they are used in packet headers, the addresses are never in
+// a V4-mapped format.
+type header4Tuple struct {
+ srcAddr tcpip.FullAddress
+ dstAddr tcpip.FullAddress
+}
+
+// testFlow implements a helper type used for sending and receiving test
+// packets. A given test flow value defines 1) the socket endpoint used for the
+// test and 2) the type of packet send or received on the endpoint. E.g., a
+// multicastV6Only flow is a V6 multicast packet passing through a V6-only
+// endpoint. The type provides helper methods to characterize the flow (e.g.,
+// isV4) as well as return a proper header4Tuple for it.
+type testFlow int
+
+const (
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+)
+
+func (flow testFlow) String() string {
+ switch flow {
+ case unicastV4:
+ return "unicastV4"
+ case unicastV6:
+ return "unicastV6"
+ case unicastV6Only:
+ return "unicastV6Only"
+ case unicastV4in6:
+ return "unicastV4in6"
+ case multicastV4:
+ return "multicastV4"
+ case multicastV6:
+ return "multicastV6"
+ case multicastV6Only:
+ return "multicastV6Only"
+ case multicastV4in6:
+ return "multicastV4in6"
+ case broadcast:
+ return "broadcast"
+ case broadcastIn6:
+ return "broadcastIn6"
+ default:
+ return "unknown"
+ }
+}
+
+// packetDirection explains if a flow is incoming (read) or outgoing (write).
+type packetDirection int
+
+const (
+ incoming packetDirection = iota
+ outgoing
+)
+
+// header4Tuple returns the header4Tuple for the given flow and direction. Note
+// that the tuple contains no mapped addresses as those only exist at the socket
+// level but not at the packet header level.
+func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
+ var h header4Tuple
+ if flow.isV4() {
+ if d == outgoing {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
+ dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
+ }
+ } else {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
+ dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
+ }
+ }
+ if flow.isMulticast() {
+ h.dstAddr.Addr = multicastAddr
+ } else if flow.isBroadcast() {
+ h.dstAddr.Addr = broadcastAddr
+ }
+ } else { // IPv6
+ if d == outgoing {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
+ dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ }
+ } else {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
+ }
+ }
+ if flow.isMulticast() {
+ h.dstAddr.Addr = multicastV6Addr
+ }
+ }
+ return h
+}
+
+func (flow testFlow) getMcastAddr() tcpip.Address {
+ if flow.isV4() {
+ return multicastAddr
+ }
+ return multicastV6Addr
+}
+
+// mapAddrIfApplicable converts the given V4 address into its V4-mapped version
+// if it is applicable to the flow.
+func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address {
+ if flow.isMapped() {
+ return v4MappedAddrPrefix + v4Addr
+ }
+ return v4Addr
+}
+
+// netProto returns the protocol number used for the network packet.
+func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
+ if flow.isV4() {
+ return ipv4.ProtocolNumber
+ }
+ return ipv6.ProtocolNumber
+}
+
+// sockProto returns the protocol number used when creating the socket
+// endpoint for this flow.
+func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
+ switch flow {
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ return ipv6.ProtocolNumber
+ case unicastV4, multicastV4, broadcast:
+ return ipv4.ProtocolNumber
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) {
+ if flow.isV4() {
+ return checker.IPv4
+ }
+ return checker.IPv6
+}
+
+func (flow testFlow) isV6() bool { return !flow.isV4() }
+func (flow testFlow) isV4() bool {
+ return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped()
+}
+
+func (flow testFlow) isV6Only() bool {
+ switch flow {
+ case unicastV6Only, multicastV6Only:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isMulticast() bool {
+ switch flow {
+ case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isBroadcast() bool {
+ switch flow {
+ case broadcast, broadcastIn6:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isMapped() bool {
+ switch flow {
+ case unicastV4in6, multicastV4in6, broadcastIn6:
+ return true
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
@@ -65,12 +271,9 @@ type testContext struct {
wq waiter.Queue
}
-type headers struct {
- srcPort uint16
- dstPort uint16
-}
-
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
+ t.Helper()
+
s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
id, linkEP := channel.New(256, mtu, "")
@@ -113,51 +316,54 @@ func (c *testContext) cleanup() {
}
}
-func (c *testContext) createV6Endpoint(v6only bool) {
+func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
+ c.t.Helper()
+
var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
+ c.t.Fatal("NewEndpoint failed: ", err)
}
+}
- var v tcpip.V6OnlyOption
- if v6only {
- v = 1
- }
- if err := c.ep.SetSockOpt(v); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
+func (c *testContext) createEndpointForFlow(flow testFlow) {
+ c.t.Helper()
+
+ c.createEndpoint(flow.sockProto())
+ if flow.isV6Only() {
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ } else if flow.isBroadcast() {
+ if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil {
+ c.t.Fatal("SetSockOpt failed:", err)
+ }
}
}
-func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte {
+// getPacketAndVerify reads a packet from the link endpoint and verifies the
+// header against expected values from the given test flow. In addition, it
+// calls any extra checker functions provided.
+func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
+ c.t.Helper()
+
select {
case p := <-c.linkEP.C:
- if p.Proto != protocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber)
+ if p.Proto != flow.netProto() {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
}
b := make([]byte, len(p.Header)+len(p.Payload))
copy(b, p.Header)
copy(b[len(p.Header):], p.Payload)
- var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker)
- var srcAddr, dstAddr tcpip.Address
- switch protocolNumber {
- case ipv4.ProtocolNumber:
- checkerFn = checker.IPv4
- srcAddr, dstAddr = stackAddr, testAddr
- if multicast {
- dstAddr = multicastAddr
- }
- case ipv6.ProtocolNumber:
- checkerFn = checker.IPv6
- srcAddr, dstAddr = stackV6Addr, testV6Addr
- if multicast {
- dstAddr = multicastV6Addr
- }
- default:
- c.t.Fatalf("unknown protocol %d", protocolNumber)
- }
- checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr))
+ h := flow.header4Tuple(outgoing)
+ checkers := append(
+ checkers,
+ checker.SrcAddr(h.srcAddr.Addr),
+ checker.DstAddr(h.dstAddr.Addr),
+ checker.UDP(checker.DstPort(h.dstAddr.Port)),
+ )
+ flow.checkerFn()(c.t, b, checkers...)
return b
case <-time.After(2 * time.Second):
@@ -167,7 +373,22 @@ func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, mult
return nil
}
-func (c *testContext) sendV6Packet(payload []byte, h *headers) {
+// injectPacket creates a packet of the given flow and with the given payload,
+// and injects it into the link endpoint.
+func (c *testContext) injectPacket(flow testFlow, payload []byte) {
+ c.t.Helper()
+
+ h := flow.header4Tuple(incoming)
+ if flow.isV4() {
+ c.injectV4Packet(payload, &h)
+ } else {
+ c.injectV6Packet(payload, &h)
+ }
+}
+
+// injectV6Packet creates a V6 test packet with the given payload and header
+// values, and injects it into the link endpoint.
+func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -178,20 +399,20 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) {
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: testV6Addr,
- DstAddr: stackV6Addr,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
u.Encode(&header.UDPFields{
- SrcPort: h.srcPort,
- DstPort: h.dstPort,
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -201,7 +422,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) {
c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
}
-func (c *testContext) sendPacket(payload []byte, h *headers) {
+// injectV6Packet creates a V4 test packet with the given payload and header
+// values, and injects it into the link endpoint.
+func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -213,21 +436,21 @@ func (c *testContext) sendPacket(payload []byte, h *headers) {
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
- SrcAddr: testAddr,
- DstAddr: stackAddr,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
ip.SetChecksum(^ip.CalculateChecksum())
// Initialize the UDP header.
u := header.UDP(buf[header.IPv4MinimumSize:])
u.Encode(&header.UDPFields{
- SrcPort: h.srcPort,
- DstPort: h.dstPort,
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testAddr, stackAddr, uint16(len(u)))
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -249,7 +472,7 @@ func TestBindPortReuse(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
var eps [5]tcpip.Endpoint
reusePortOpt := tcpip.ReusePortOption(1)
@@ -292,9 +515,9 @@ func TestBindPortReuse(t *testing.T) {
// Send a packet.
port := uint16(i % nports)
payload := newPayload()
- c.sendV6Packet(payload, &headers{
- srcPort: testPort + port,
- dstPort: stackPort,
+ c.injectV6Packet(payload, &header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port},
+ dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
})
var addr tcpip.FullAddress
@@ -329,13 +552,14 @@ func TestBindPortReuse(t *testing.T) {
}
}
-func testV4Read(c *testContext) {
- // Send a packet.
+// testRead sends a packet of the given test flow into the stack by injecting it
+// into the link endpoint. It then reads it from the UDP endpoint and verifies
+// its correctness.
+func testRead(c *testContext, flow testFlow) {
+ c.t.Helper()
+
payload := newPayload()
- c.sendPacket(payload, &headers{
- srcPort: testPort,
- dstPort: stackPort,
- })
+ c.injectPacket(flow, payload)
// Try to receive the data.
we, ch := waiter.NewChannelEntry(nil)
@@ -359,8 +583,9 @@ func testV4Read(c *testContext) {
}
// Check the peer address.
- if addr.Addr != testAddr {
- c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
+ h := flow.header4Tuple(incoming)
+ if addr.Addr != h.srcAddr.Addr {
+ c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr)
}
// Check the payload.
@@ -373,7 +598,7 @@ func TestBindEphemeralPort(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
t.Fatalf("ep.Bind(...) failed: %v", err)
@@ -384,7 +609,7 @@ func TestBindReservedPort(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
@@ -443,7 +668,7 @@ func TestV4ReadOnV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV4in6)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
@@ -451,29 +676,29 @@ func TestV4ReadOnV6(t *testing.T) {
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4in6)
}
func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV4in6)
// Bind to v4 mapped wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}); err != nil {
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4in6)
}
func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV4in6)
// Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
@@ -481,69 +706,29 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4in6)
}
func TestV6ReadOnV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV6)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- // Send a packet.
- payload := newPayload()
- c.sendV6Packet(payload, &headers{
- srcPort: testPort,
- dstPort: stackPort,
- })
-
- // Try to receive the data.
- we, ch := waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&we, waiter.EventIn)
- defer c.wq.EventUnregister(&we)
-
- var addr tcpip.FullAddress
- v, _, err := c.ep.Read(&addr)
- if err == tcpip.ErrWouldBlock {
- // Wait for data to become available.
- select {
- case <-ch:
- v, _, err = c.ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for data")
- }
- }
-
- // Check the peer address.
- if addr.Addr != testV6Addr {
- c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
- }
-
- // Check the payload.
- if !bytes.Equal(payload, v) {
- c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
- }
+ // Test acceptance.
+ testRead(c, unicastV6)
}
func TestV4ReadOnV4(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- // Create v4 UDP endpoint.
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
+ c.createEndpointForFlow(unicastV4)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
@@ -551,45 +736,108 @@ func TestV4ReadOnV4(t *testing.T) {
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4)
}
-func testV4Write(c *testContext) uint16 {
- // Write to V4 mapped address.
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
- })
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
+// address and receive data sent to that address.
+func TestReadOnBoundToMulticast(t *testing.T) {
+ // FIXME(b/128189410): multicastV4in6 currently doesn't work as
+ // AddMembershipOption doesn't handle V4in6 addresses.
+ for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to multicast address.
+ mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr())
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ // Join multicast group.
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatal("SetSockOpt failed:", err)
+ }
+
+ testRead(c, flow)
+ })
}
- if n != int64(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+}
+
+// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
+// address and receive broadcast data on it.
+func TestV4ReadOnBoundToBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to broadcast address.
+ bcastAddr := flow.mapAddrIfApplicable(broadcastAddr)
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Test acceptance.
+ testRead(c, flow)
+ })
}
+}
- // Check that we received the packet.
- b := c.getPacket(ipv4.ProtocolNumber, false)
- udp := header.UDP(header.IPv4(b).Payload())
- checker.IPv4(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
+// testFailingWrite sends a packet of the given test flow into the UDP endpoint
+// and verifies it fails with the provided error code.
+func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
+ c.t.Helper()
- // Check the payload.
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ h := flow.header4Tuple(outgoing)
+ writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
+
+ payload := buffer.View(newPayload())
+ _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
+ })
+ if gotErr != wantErr {
+ c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
}
+}
- return udp.SourcePort()
+// testWrite sends a packet of the given test flow from the UDP endpoint to the
+// flow's destination address:port. It then receives it from the link endpoint
+// and verifies its correctness including any additional checker functions
+// provided.
+func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ return testWriteInternal(c, flow, true, checkers...)
}
-func testV6Write(c *testContext) uint16 {
- // Write to v6 address.
+// testWriteWithoutDestination sends a packet of the given test flow from the
+// UDP endpoint without giving a destination address:port. It then receives it
+// from the link endpoint and verifies its correctness including any additional
+// checker functions provided.
+func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ return testWriteInternal(c, flow, false, checkers...)
+}
+
+func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+
+ writeOpts := tcpip.WriteOptions{}
+ if setDest {
+ h := flow.header4Tuple(outgoing)
+ writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
+ writeOpts = tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
+ }
+ }
payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- })
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
@@ -597,16 +845,14 @@ func testV6Write(c *testContext) uint16 {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
- // Check that we received the packet.
- b := c.getPacket(ipv6.ProtocolNumber, false)
- udp := header.UDP(header.IPv6(b).Payload())
- checker.IPv6(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
-
- // Check the payload.
+ // Received the packet and check the payload.
+ b := c.getPacketAndVerify(flow, checkers...)
+ var udp header.UDP
+ if flow.isV4() {
+ udp = header.UDP(header.IPv4(b).Payload())
+ } else {
+ udp = header.UDP(header.IPv6(b).Payload())
+ }
if !bytes.Equal(payload, udp.Payload()) {
c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
}
@@ -615,8 +861,10 @@ func testV6Write(c *testContext) uint16 {
}
func testDualWrite(c *testContext) uint16 {
- v4Port := testV4Write(c)
- v6Port := testV6Write(c)
+ c.t.Helper()
+
+ v4Port := testWrite(c, unicastV4in6)
+ v6Port := testWrite(c, unicastV6)
if v4Port != v6Port {
c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
}
@@ -628,7 +876,7 @@ func TestDualWriteUnbound(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
testDualWrite(c)
}
@@ -637,7 +885,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
@@ -654,69 +902,51 @@ func TestDualWriteConnectedToV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- testV6Write(c)
+ testWrite(c, unicastV6)
// Write to V4 mapped address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
- })
- if err != tcpip.ErrNetworkUnreachable {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable)
- }
+ testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
}
func TestDualWriteConnectedToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- testV4Write(c)
+ testWrite(c, unicastV4in6)
// Write to v6 address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- })
- if err != tcpip.ErrInvalidEndpointState {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState)
- }
+ testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
}
func TestV4WriteOnV6Only(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(true)
+ c.createEndpointForFlow(unicastV6Only)
// Write to V4 mapped address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
- })
- if err != tcpip.ErrNoRoute {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
- }
+ testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute)
}
func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Bind to v4 mapped address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
@@ -724,84 +954,154 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
}
// Write to v6 address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- })
- if err != tcpip.ErrInvalidEndpointState {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState)
- }
+ testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
}
func TestV6WriteOnConnected(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
}
- // Write without destination.
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
- }
- if n != int64(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
- }
-
- // Check that we received the packet.
- b := c.getPacket(ipv6.ProtocolNumber, false)
- udp := header.UDP(header.IPv6(b).Payload())
- checker.IPv6(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
-
- // Check the payload.
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
- }
+ testWriteWithoutDestination(c, unicastV6)
}
func TestV4WriteOnConnected(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
}
- // Write without destination.
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+ testWriteWithoutDestination(c, unicastV4)
+}
+
+// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
+// that is bound to a V4 multicast address.
+func TestWriteOnBoundToV4Multicast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ testWrite(c, flow)
+ })
}
- if n != int64(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+}
+
+// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
+// socket that is bound to a V4-mapped multicast address.
+func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4Mapped mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
}
+}
- // Check that we received the packet.
- b := c.getPacket(ipv4.ProtocolNumber, false)
- udp := header.UDP(header.IPv4(b).Payload())
- checker.IPv4(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
+// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
+// socket that is bound to a V6 multicast address.
+func TestWriteOnBoundToV6Multicast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV6, multicastV6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- // Check the payload.
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ c.createEndpointForFlow(flow)
+
+ // Bind to V6 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
+// V6-only socket that is bound to a V6 multicast address.
+func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV6Only, multicastV6Only} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V6 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToBroadcast checks that we can send packets out of a
+// socket that is bound to the broadcast address.
+func TestWriteOnBoundToBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4 broadcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
+// socket that is bound to the V4-mapped broadcast address.
+func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4Mapped mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
}
}
@@ -810,18 +1110,14 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
defer c.cleanup()
// Create IPv4 UDP endpoint
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
+ c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- testV4Read(c)
+ testRead(c, unicastV4)
var want uint64 = 1
if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
@@ -833,7 +1129,7 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
testDualWrite(c)
@@ -843,244 +1139,102 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
}
}
-func setSockOptVariants(t *testing.T, optFunc func(*testing.T, string, tcpip.NetworkProtocolNumber, string)) {
- for _, name := range []string{"v4", "v6", "dual"} {
- t.Run(name, func(t *testing.T) {
- var networkProtocolNumber tcpip.NetworkProtocolNumber
- switch name {
- case "v4":
- networkProtocolNumber = ipv4.ProtocolNumber
- case "v6", "dual":
- networkProtocolNumber = ipv6.ProtocolNumber
- default:
- t.Fatal("unknown test variant")
- }
+func TestTTL(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- var variants []string
- switch name {
- case "v4":
- variants = []string{"v4"}
- case "v6":
- variants = []string{"v6"}
- case "dual":
- variants = []string{"v6", "mapped"}
- }
+ c.createEndpointForFlow(flow)
- for _, variant := range variants {
- t.Run(variant, func(t *testing.T) {
- optFunc(t, name, networkProtocolNumber, variant)
- })
+ const multicastTTL = 42
+ if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
}
- })
- }
-}
-func TestTTL(t *testing.T) {
- payload := tcpip.SlicePayload(buffer.View(newPayload()))
-
- setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) {
- for _, typ := range []string{"unicast", "multicast"} {
- t.Run(typ, func(t *testing.T) {
- var addr tcpip.Address
- var port uint16
- switch typ {
- case "unicast":
- port = testPort
- switch variant {
- case "v4":
- addr = testAddr
- case "mapped":
- addr = testV4MappedAddr
- case "v6":
- addr = testV6Addr
- default:
- t.Fatal("unknown test variant")
- }
- case "multicast":
- port = multicastPort
- switch variant {
- case "v4":
- addr = multicastAddr
- case "mapped":
- addr = multicastV4MappedAddr
- case "v6":
- addr = multicastV6Addr
- default:
- t.Fatal("unknown test variant")
- }
- default:
- t.Fatal("unknown test variant")
+ var wantTTL uint8
+ if flow.isMulticast() {
+ wantTTL = multicastTTL
+ } else {
+ var p stack.NetworkProtocol
+ if flow.isV4() {
+ p = ipv4.NewProtocol()
+ } else {
+ p = ipv6.NewProtocol()
}
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- switch name {
- case "v4":
- case "v6":
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
- case "dual":
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
- default:
- t.Fatal("unknown test variant")
- }
-
- const multicastTTL = 42
- if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ t.Fatal(err)
}
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ }
- n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
- }
- if n != int64(len(payload)) {
- c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload))
- }
+ testWrite(c, flow, checker.TTL(wantTTL))
+ })
+ }
+}
- checkerFn := checker.IPv4
- switch variant {
- case "v4", "mapped":
- case "v6":
- checkerFn = checker.IPv6
- default:
- t.Fatal("unknown test variant")
- }
- var wantTTL uint8
- var multicast bool
- switch typ {
- case "unicast":
- multicast = false
- switch variant {
- case "v4", "mapped":
- ep, err := ipv4.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
- wantTTL = ep.DefaultTTL()
- ep.Close()
- case "v6":
- ep, err := ipv6.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
- wantTTL = ep.DefaultTTL()
- ep.Close()
- default:
- t.Fatal("unknown test variant")
- }
- case "multicast":
- wantTTL = multicastTTL
- multicast = true
- default:
- t.Fatal("unknown test variant")
- }
+func TestMulticastInterfaceOption(t *testing.T) {
+ for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ for _, bindTyp := range []string{"bound", "unbound"} {
+ t.Run(bindTyp, func(t *testing.T) {
+ for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
+ t.Run(optTyp, func(t *testing.T) {
+ h := flow.header4Tuple(outgoing)
+ mcastAddr := h.dstAddr.Addr
+ localIfAddr := h.srcAddr.Addr
+
+ var ifoptSet tcpip.MulticastInterfaceOption
+ switch optTyp {
+ case "use local-addr":
+ ifoptSet.InterfaceAddr = localIfAddr
+ case "use NICID":
+ ifoptSet.NIC = 1
+ case "use local-addr and NIC":
+ ifoptSet.InterfaceAddr = localIfAddr
+ ifoptSet.NIC = 1
+ default:
+ t.Fatal("unknown test variant")
+ }
- var networkProtocolNumber tcpip.NetworkProtocolNumber
- switch variant {
- case "v4", "mapped":
- networkProtocolNumber = ipv4.ProtocolNumber
- case "v6":
- networkProtocolNumber = ipv6.ProtocolNumber
- default:
- t.Fatal("unknown test variant")
- }
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(flow.sockProto())
+
+ if bindTyp == "bound" {
+ // Bind the socket by connecting to the multicast address.
+ // This may have an influence on how the multicast interface
+ // is set.
+ addr := tcpip.FullAddress{
+ Addr: flow.mapAddrIfApplicable(mcastAddr),
+ Port: stackPort,
+ }
+ if err := c.ep.Connect(addr); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+ }
- b := c.getPacket(networkProtocolNumber, multicast)
- checkerFn(c.t, b,
- checker.TTL(wantTTL),
- checker.UDP(
- checker.DstPort(port),
- ),
- )
- })
- }
- })
-}
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
-func TestMulticastInterfaceOption(t *testing.T) {
- setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) {
- for _, bindTyp := range []string{"bound", "unbound"} {
- t.Run(bindTyp, func(t *testing.T) {
- for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
- t.Run(optTyp, func(t *testing.T) {
- var mcastAddr, localIfAddr tcpip.Address
- switch variant {
- case "v4":
- mcastAddr = multicastAddr
- localIfAddr = stackAddr
- case "mapped":
- mcastAddr = multicastV4MappedAddr
- localIfAddr = stackAddr
- case "v6":
- mcastAddr = multicastV6Addr
- localIfAddr = stackV6Addr
- default:
- t.Fatal("unknown test variant")
- }
-
- var ifoptSet tcpip.MulticastInterfaceOption
- switch optTyp {
- case "use local-addr":
- ifoptSet.InterfaceAddr = localIfAddr
- case "use NICID":
- ifoptSet.NIC = 1
- case "use local-addr and NIC":
- ifoptSet.InterfaceAddr = localIfAddr
- ifoptSet.NIC = 1
- default:
- t.Fatal("unknown test variant")
- }
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if bindTyp == "bound" {
- // Bind the socket by connecting to the multicast address.
- // This may have an influence on how the multicast interface
- // is set.
- addr := tcpip.FullAddress{
- Addr: mcastAddr,
- Port: multicastPort,
+ // Verify multicast interface addr and NIC were set correctly.
+ // Note that NIC must be 1 since this is our outgoing interface.
+ ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
+ var ifoptGot tcpip.MulticastInterfaceOption
+ if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
+ c.t.Fatalf("GetSockOpt failed: %v", err)
}
- if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ if ifoptGot != ifoptWant {
+ c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
}
- }
-
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
-
- // Verify multicast interface addr and NIC were set correctly.
- // Note that NIC must be 1 since this is our outgoing interface.
- ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
- var ifoptGot tcpip.MulticastInterfaceOption
- if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %v", err)
- }
- if ifoptGot != ifoptWant {
- c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
- }
- })
- }
- })
- }
- })
+ })
+ }
+ })
+ }
+ })
+ }
}