summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-01-06 11:39:14 -0800
committergVisor bot <gvisor-bot@google.com>2021-01-06 11:41:42 -0800
commitabe9d9f67f2c2c696ef26690fa8518dfc4e28728 (patch)
treeeabba6c15b22345ab322f7380548c4c8eb5e50fe /pkg
parent0c4118d5b8428648c252df55a7867ac6f287f146 (diff)
Support add/remove IPv6 multicast group sock opt
IPv4 was always supported but UDP never supported joining/leaving IPv6 multicast groups via socket options. Add: IPPROTO_IPV6, IPV6_JOIN_GROUP/IPV6_ADD_MEMBERSHIP Remove: IPPROTO_IPV6, IPV6_LEAVE_GROUP/IPV6_DROP_MEMBERSHIP Test: integration_test.TestUDPAddRemoveMembershipSocketOption PiperOrigin-RevId: 350396072
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/socket.go6
-rw-r--r--pkg/sentry/socket/netstack/netstack.go37
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go474
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go14
4 files changed, 372 insertions, 159 deletions
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 556892dc3..8591acbf2 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -250,6 +250,12 @@ type SockAddrInet struct {
_ [8]uint8 // pad to sizeof(struct sockaddr).
}
+// Inet6MulticastRequest is struct ipv6_mreq, from uapi/linux/in6.h.
+type Inet6MulticastRequest struct {
+ MulticastAddr Inet6Addr
+ InterfaceIndex int32
+}
+
// InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h.
type InetMulticastRequest struct {
MulticastAddr InetAddr
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 3f587638f..fe11fca9c 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -2091,9 +2091,29 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
ep.SocketOptions().SetV6Only(v != 0)
return nil
- case linux.IPV6_ADD_MEMBERSHIP,
- linux.IPV6_DROP_MEMBERSHIP,
- linux.IPV6_IPSEC_POLICY,
+ case linux.IPV6_ADD_MEMBERSHIP:
+ req, err := copyInMulticastV6Request(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.AddMembershipOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ MulticastAddr: tcpip.Address(req.MulticastAddr[:]),
+ }))
+
+ case linux.IPV6_DROP_MEMBERSHIP:
+ req, err := copyInMulticastV6Request(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.RemoveMembershipOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ MulticastAddr: tcpip.Address(req.MulticastAddr[:]),
+ }))
+
+ case linux.IPV6_IPSEC_POLICY,
linux.IPV6_JOIN_ANYCAST,
linux.IPV6_LEAVE_ANYCAST,
// TODO(b/148887420): Add support for IPV6_PKTINFO.
@@ -2181,6 +2201,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
var (
inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{}))
inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{}))
+ inet6MulticastRequestSize = int(binary.Size(linux.Inet6MulticastRequest{}))
)
// copyInMulticastRequest copies in a variable-size multicast request. The
@@ -2214,6 +2235,16 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR
return req, nil
}
+func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syserr.Error) {
+ if len(optVal) < inet6MulticastRequestSize {
+ return linux.Inet6MulticastRequest{}, syserr.ErrInvalidArgument
+ }
+
+ var req linux.Inet6MulticastRequest
+ binary.Unmarshal(optVal[:inet6MulticastRequestSize], usermem.ByteOrder, &req)
+ return req, nil
+}
+
// parseIntOrChar copies either a 32-bit int or an 8-bit uint out of buf.
//
// net/ipv4/ip_sockglue.c:do_ip_setsockopt does this for its socket options.
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 2e59f6a42..20f8a7e6c 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -35,6 +35,9 @@ import (
const (
defaultMTU = 1280
ttl = 255
+
+ remotePort = 5555
+ localPort = 80
)
var (
@@ -151,11 +154,11 @@ func TestPingMulticastBroadcast(t *testing.T) {
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err)
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
}
ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr}
if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err)
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err)
}
// Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
@@ -215,163 +218,219 @@ func TestPingMulticastBroadcast(t *testing.T) {
}
+func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
+ payloadLen := header.UDPMinimumSize + len(data)
+ totalLen := header.IPv4MinimumSize + payloadLen
+ hdr := buffer.NewPrependable(totalLen)
+ u := header.UDP(hdr.Prepend(payloadLen))
+ u.Encode(&header.UDPFields{
+ SrcPort: remotePort,
+ DstPort: localPort,
+ Length: uint16(payloadLen),
+ })
+ copy(u.Payload(), data)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
+ sum = header.Checksum(data, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(udp.ProtocolNumber),
+ TTL: ttl,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+}
+
+func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
+ payloadLen := header.UDPMinimumSize + len(data)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
+ u := header.UDP(hdr.Prepend(payloadLen))
+ u.Encode(&header.UDPFields{
+ SrcPort: remotePort,
+ DstPort: localPort,
+ Length: uint16(payloadLen),
+ })
+ copy(u.Payload(), data)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
+ sum = header.Checksum(data, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLen),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: ttl,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+}
+
// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some
// multicast or broadcast address.
func TestIncomingMulticastAndBroadcast(t *testing.T) {
- const (
- nicID = 1
- remotePort = 5555
- localPort = 80
- )
+ const nicID = 1
data := []byte{1, 2, 3, 4}
- rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) {
- payloadLen := header.UDPMinimumSize + len(data)
- totalLen := header.IPv4MinimumSize + payloadLen
- hdr := buffer.NewPrependable(totalLen)
- u := header.UDP(hdr.Prepend(payloadLen))
- u.Encode(&header.UDPFields{
- SrcPort: remotePort,
- DstPort: localPort,
- Length: uint16(payloadLen),
- })
- copy(u.Payload(), data)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen))
- sum = header.Checksum(data, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(udp.ProtocolNumber),
- TTL: ttl,
- SrcAddr: remoteIPv4Addr,
- DstAddr: dst,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) {
- payloadLen := header.UDPMinimumSize + len(data)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
- u := header.UDP(hdr.Prepend(payloadLen))
- u.Encode(&header.UDPFields{
- SrcPort: remotePort,
- DstPort: localPort,
- Length: uint16(payloadLen),
- })
- copy(u.Payload(), data)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen))
- sum = header.Checksum(data, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLen),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: ttl,
- SrcAddr: remoteIPv6Addr,
- DstAddr: dst,
- })
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
tests := []struct {
- name string
- bindAddr tcpip.Address
- dstAddr tcpip.Address
- expectRx bool
+ name string
+ proto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ localAddr tcpip.AddressWithPrefix
+ rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
+ bindAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectRx bool
}{
{
- name: "IPv4 unicast binding to unicast",
- bindAddr: ipv4Addr.Address,
- dstAddr: ipv4Addr.Address,
- expectRx: true,
+ name: "IPv4 unicast binding to unicast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: ipv4Addr.Address,
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
},
{
- name: "IPv4 unicast binding to broadcast",
- bindAddr: header.IPv4Broadcast,
- dstAddr: ipv4Addr.Address,
- expectRx: false,
+ name: "IPv4 unicast binding to broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: ipv4Addr.Address,
+ expectRx: false,
},
{
- name: "IPv4 unicast binding to wildcard",
- dstAddr: ipv4Addr.Address,
- expectRx: true,
+ name: "IPv4 unicast binding to wildcard",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
},
{
- name: "IPv4 directed broadcast binding to subnet broadcast",
- bindAddr: ipv4SubnetBcast,
- dstAddr: ipv4SubnetBcast,
- expectRx: true,
+ name: "IPv4 directed broadcast binding to subnet broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: ipv4SubnetBcast,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
},
{
- name: "IPv4 directed broadcast binding to broadcast",
- bindAddr: header.IPv4Broadcast,
- dstAddr: ipv4SubnetBcast,
- expectRx: false,
+ name: "IPv4 directed broadcast binding to broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: false,
},
{
- name: "IPv4 directed broadcast binding to wildcard",
- dstAddr: ipv4SubnetBcast,
- expectRx: true,
+ name: "IPv4 directed broadcast binding to wildcard",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
},
{
- name: "IPv4 broadcast binding to broadcast",
- bindAddr: header.IPv4Broadcast,
- dstAddr: header.IPv4Broadcast,
- expectRx: true,
+ name: "IPv4 broadcast binding to broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: header.IPv4Broadcast,
+ expectRx: true,
},
{
- name: "IPv4 broadcast binding to subnet broadcast",
- bindAddr: ipv4SubnetBcast,
- dstAddr: header.IPv4Broadcast,
- expectRx: false,
+ name: "IPv4 broadcast binding to subnet broadcast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: ipv4SubnetBcast,
+ dstAddr: header.IPv4Broadcast,
+ expectRx: false,
},
{
- name: "IPv4 broadcast binding to wildcard",
- dstAddr: ipv4SubnetBcast,
- expectRx: true,
+ name: "IPv4 broadcast binding to wildcard",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
},
{
- name: "IPv4 all-systems multicast binding to all-systems multicast",
- bindAddr: header.IPv4AllSystems,
- dstAddr: header.IPv4AllSystems,
- expectRx: true,
+ name: "IPv4 all-systems multicast binding to all-systems multicast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: header.IPv4AllSystems,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: true,
},
{
- name: "IPv4 all-systems multicast binding to wildcard",
- dstAddr: header.IPv4AllSystems,
- expectRx: true,
+ name: "IPv4 all-systems multicast binding to wildcard",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: true,
},
{
- name: "IPv4 all-systems multicast binding to unicast",
- bindAddr: ipv4Addr.Address,
- dstAddr: header.IPv4AllSystems,
- expectRx: false,
+ name: "IPv4 all-systems multicast binding to unicast",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ bindAddr: ipv4Addr.Address,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: false,
},
// IPv6 has no notion of a broadcast.
{
- name: "IPv6 unicast binding to wildcard",
- dstAddr: ipv6Addr.Address,
- expectRx: true,
+ name: "IPv6 unicast binding to wildcard",
+ dstAddr: ipv6Addr.Address,
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: remoteIPv6Addr,
+ localAddr: ipv6Addr,
+ rxUDP: rxIPv6UDP,
+ expectRx: true,
},
{
- name: "IPv6 broadcast-like address binding to wildcard",
- dstAddr: ipv6SubnetBcast,
- expectRx: false,
+ name: "IPv6 broadcast-like address binding to wildcard",
+ dstAddr: ipv6SubnetBcast,
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: remoteIPv6Addr,
+ localAddr: ipv6Addr,
+ rxUDP: rxIPv6UDP,
+ expectRx: false,
},
}
@@ -385,41 +444,24 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err)
- }
- ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr}
- if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err)
- }
-
- var netproto tcpip.NetworkProtocolNumber
- var rxUDP func(*channel.Endpoint, tcpip.Address)
- switch l := len(test.dstAddr); l {
- case header.IPv4AddressSize:
- netproto = header.IPv4ProtocolNumber
- rxUDP = rxIPv4UDP
- case header.IPv6AddressSize:
- netproto = header.IPv6ProtocolNumber
- rxUDP = rxIPv6UDP
- default:
- t.Fatalf("got unexpected address length = %d bytes", l)
+ protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
}
var wq waiter.Queue
- ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err)
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
}
defer ep.Close()
bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("ep.Bind(%+v): %s", bindAddr, err)
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
}
- rxUDP(e, test.dstAddr)
+ test.rxUDP(e, test.remoteAddr, test.dstAddr, data)
if gotPayload, _, err := ep.Read(nil); test.expectRx {
if err != nil {
t.Fatalf("Read(nil): %s", err)
@@ -476,7 +518,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
},
}
if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err)
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -516,12 +558,12 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
bindAddr := tcpip.FullAddress{Port: localPort}
if bindWildcard {
if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
}
} else {
bindAddr.Addr = test.broadcastAddr
if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
}
}
@@ -557,3 +599,143 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
})
}
}
+
+func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
+ const (
+ nicID = 1
+ )
+
+ data := []byte{1, 2, 3, 4}
+
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ localAddr tcpip.AddressWithPrefix
+ rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
+ multicastAddr tcpip.Address
+ }{
+ {
+ name: "IPv4 unicast binding to unicast",
+ multicastAddr: "\xe0\x01\x02\x03",
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: remoteIPv4Addr,
+ localAddr: ipv4Addr,
+ rxUDP: rxIPv4UDP,
+ },
+ {
+ name: "IPv6 broadcast-like address binding to wildcard",
+ multicastAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04",
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: remoteIPv6Addr,
+ localAddr: ipv6Addr,
+ rxUDP: rxIPv6UDP,
+ },
+ }
+
+ subTests := []struct {
+ name string
+ specifyNICID bool
+ specifyNICAddr bool
+ }{
+ {
+ name: "Specify NIC ID and NIC address",
+ specifyNICID: true,
+ specifyNICAddr: true,
+ },
+ {
+ name: "Don't specify NIC ID or NIC address",
+ specifyNICID: false,
+ specifyNICAddr: false,
+ },
+ {
+ name: "Specify NIC ID but don't specify NIC address",
+ specifyNICID: true,
+ specifyNICAddr: false,
+ },
+ {
+ name: "Don't specify NIC ID but specify NIC address",
+ specifyNICID: false,
+ specifyNICAddr: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ }
+
+ // Set the route table so that UDP can find a NIC that is
+ // routable to the multicast address when the NIC isn't specified.
+ if !subTest.specifyNICID && !subTest.specifyNICAddr {
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+ }
+
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Port: localPort}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+ }
+
+ memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr}
+ if subTest.specifyNICID {
+ memOpt.NIC = nicID
+ }
+ if subTest.specifyNICAddr {
+ memOpt.InterfaceAddr = test.localAddr.Address
+ }
+
+ // We should receive UDP packets to the group once we join the
+ // multicast group.
+ addOpt := tcpip.AddMembershipOption(memOpt)
+ if err := ep.SetSockOpt(&addOpt); err != nil {
+ t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err)
+ }
+ test.rxUDP(e, test.remoteAddr, test.multicastAddr, data)
+ if gotPayload, _, err := ep.Read(nil); err != nil {
+ t.Fatalf("ep.Read(nil): %s", err)
+ } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
+
+ // We should not receive UDP packets to the group once we leave
+ // the multicast group.
+ removeOpt := tcpip.RemoveMembershipOption(memOpt)
+ if err := ep.SetSockOpt(&removeOpt); err != nil {
+ t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err)
+ }
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got ep.Read(nil) = (%x, _, %s), want = (nil, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 9b9e4deb0..4e8bd8b04 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -708,14 +708,9 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
nicID := v.NIC
- // The interface address is considered not-set if it is empty or contains
- // all-zeros. The former represent the zero-value in golang, the latter the
- // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall.
- allZeros := header.IPv4Any
- if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros {
+ if v.InterfaceAddr.Unspecified() {
if nicID == 0 {
- r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
- if err == nil {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
nicID = r.NICID()
r.Release()
}
@@ -748,10 +743,9 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
}
nicID := v.NIC
- if v.InterfaceAddr == header.IPv4Any {
+ if v.InterfaceAddr.Unspecified() {
if nicID == 0 {
- r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
- if err == nil {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
nicID = r.NICID()
r.Release()
}