From 48745251611b5c152b1a2b66a0f2f30dd4dc1ed9 Mon Sep 17 00:00:00 2001 From: Chris Kuiper Date: Thu, 3 Oct 2019 19:30:01 -0700 Subject: Implement proper local broadcast behavior The behavior for sending and receiving local broadcast (255.255.255.255) traffic is as follows: Outgoing -------- * A broadcast packet sent on a socket that is bound to an interface goes out that interface * A broadcast packet sent on an unbound socket follows the route table to select the outgoing interface + if an explicit route entry exists for 255.255.255.255/32, use that one + else use the default route * Broadcast packets are looped back and delivered following the rules for incoming packets (see next). This is the same behavior as for multicast packets, except that it cannot be disabled via sockopt. Incoming -------- * Sockets wishing to receive broadcast packets must bind to either INADDR_ANY (0.0.0.0) or INADDR_BROADCAST (255.255.255.255). No other socket receives broadcast packets. * Broadcast packets are multiplexed to all sockets matching it. This is the same behavior as for multicast packets. * A socket can bind to 255.255.255.255: and then receive its own broadcast packets sent to 255.255.255.255: In addition, this change implicitly fixes an issue with multicast reception. If two sockets want to receive a given multicast stream and one is bound to ANY while the other is bound to the multicast address, only one of them will receive the traffic. PiperOrigin-RevId: 272792377 --- pkg/tcpip/network/ipv4/ipv4_test.go | 3 - pkg/tcpip/stack/nic.go | 27 +- pkg/tcpip/stack/route.go | 13 +- pkg/tcpip/stack/stack.go | 2 +- pkg/tcpip/stack/stack_test.go | 123 ++++++- pkg/tcpip/stack/transport_demuxer.go | 47 +-- pkg/tcpip/tcpip.go | 44 +++ pkg/tcpip/tcpip_test.go | 31 ++ pkg/tcpip/transport/udp/udp_test.go | 78 +++- .../socket_ipv4_udp_unbound_external_networking.cc | 408 +++++++++++++++++---- 10 files changed, 636 insertions(+), 140 deletions(-) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index b6641ccc3..a53894c01 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -47,9 +47,6 @@ func TestExcludeBroadcast(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { t.Fatalf("AddAddress failed: %v", err) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f6106f762..5993fe582 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -104,6 +104,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback func (n *NIC) enable() *tcpip.Error { n.attachLinkEndpoint() + // Create an endpoint to receive broadcast packets on this interface. + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + if err := n.AddAddress(tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, + }, NeverPrimaryEndpoint); err != nil { + return err + } + } + // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives // and responds to the various NDP messages that are destined to the @@ -372,7 +382,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { - // Don't include expired or tempory endpoints to avoid confusion and + // Don't include expired or temporary endpoints to avoid confusion and // prevent the caller from using those. switch ref.getKind() { case permanentExpired, temporary: @@ -624,21 +634,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.stack.AddLinkAddress(n.id, src, remote) - // If the packet is destined to the IPv4 Broadcast address, then make a - // route to each IPv4 network endpoint and let each endpoint handle the - // packet. - if dst == header.IPv4Broadcast { - // n.endpoints is mutex protected so acquire lock. - n.mu.RLock() - for _, ref := range n.endpoints { - if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) - } - } - n.mu.RUnlock() - return - } - if ref := n.getRef(protocol, dst); ref != nil { handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) return diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 5c8b7977a..0b09e6517 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -59,6 +59,8 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop = PacketLoop } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop } return Route{ @@ -208,10 +210,17 @@ func (r *Route) Clone() Route { return *r } -// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast. +// MakeLoopedRoute duplicates the given route with special handling for routes +// used for sending multicast or broadcast packets. In those cases the +// multicast/broadcast address is the remote address when sending out, but for +// incoming (looped) packets it becomes the local address. Similarly, the local +// interface address that was the local address going out becomes the remote +// address coming in. This is different to unicast routes where local and +// remote addresses remain the same as they identify location (local vs remote) +// not direction (source vs destination). func (r *Route) MakeLoopedRoute() Route { l := r.Clone() - if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { + if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress l.RemoteLinkAddress = l.LocalLinkAddress } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 90c2cf1be..ff574a055 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -902,7 +902,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } else { for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) { + if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } if nic, ok := s.nics[route.NIC]; ok { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d2dede8a9..a2e0a6e7b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -952,10 +952,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet works. testSendTo(t, s, dstAddr, ep, nil) @@ -967,10 +967,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != localAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. testSend(t, r, ep, nil) @@ -1016,17 +1016,33 @@ func TestSpoofingNoAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } -func TestBroadcastNeedsNoRoute(t *testing.T) { +func verifyRoute(gotRoute, wantRoute stack.Route) error { + if gotRoute.LocalAddress != wantRoute.LocalAddress { + return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) + } + if gotRoute.RemoteAddress != wantRoute.RemoteAddress { + return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) + } + if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { + return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) + } + if gotRoute.NextHop != wantRoute.NextHop { + return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) + } + return nil +} + +func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) @@ -1039,28 +1055,99 @@ func TestBroadcastNeedsNoRoute(t *testing.T) { // If there is no endpoint, it won't work. if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } - if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %s", fakeNetNumber, header.IPv4Any, err) + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + if err := s.AddProtocolAddress(1, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if r.LocalAddress != header.IPv4Any { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, header.IPv4Any) + // If the NIC doesn't exist, it won't work. + if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { + t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } +} + +func TestOutgoingBroadcastWithRouteTable(t *testing.T) { + defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. + nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") + // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. + nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") - if r.RemoteAddress != header.IPv4Broadcast { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, header.IPv4Broadcast) + // Create a new stack with two NICs. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + if err := s.CreateNIC(2, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) } - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { + t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) + } + + // Set the initial route table. + rt := []tcpip.Route{ + {Destination: nic1Addr.Subnet(), NIC: 1}, + {Destination: nic2Addr.Subnet(), NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1}, + } + s.SetRouteTable(rt) + + // When an interface is given, the route for a broadcast goes through it. + r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + + // When an interface is not given, it consults the route table. + // 1. Case: Using the default route. + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) + } + + // 2. Case: Having an explicit route for broadcast will select that one. + rt = append( + []tcpip.Route{ + {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + }, + rt..., + ) + s.SetRouteTable(rt) + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 8c768c299..92267ce4d 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -63,7 +63,7 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { + if isMulticastOrBroadcast(id.LocalAddress) { mpep.handlePacketAll(r, id, vv) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return @@ -338,23 +338,14 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - // If a sender bound to the Loopback interface sends a broadcast, - // that broadcast must not be delivered to the sender. - if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort { - return false - } - - // If the packet is a broadcast, then find all matching transport endpoints. - // Otherwise, try to find a single matching transport endpoint. - destEps := make([]*endpointsByNic, 0, 1) eps.mu.RLock() - if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast { - for epID, endpoint := range eps.endpoints { - if epID.LocalPort == id.LocalPort { - destEps = append(destEps, endpoint) - } - } + // Determine which transport endpoint or endpoints to deliver this packet to. + // If the packet is a broadcast or multicast, then find all matching + // transport endpoints. + var destEps []*endpointsByNic + if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + destEps = d.findAllEndpointsLocked(eps, vv, id) } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { destEps = append(destEps, ep) } @@ -426,10 +417,11 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return true } -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic { + var matchedEPs []*endpointsByNic // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the local address. @@ -437,7 +429,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the remote part. @@ -445,15 +437,24 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.RemoteAddress = "" nid.RemotePort = 0 if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with only the local port. nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } + return matchedEPs +} + +// findEndpointLocked returns the endpoint that most closely matches the given +// id. +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { + if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 { + return matchedEPs[0] + } return nil } @@ -491,3 +492,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN } } } + +func isMulticastOrBroadcast(addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index faaa4a4e3..70e7575f5 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -57,6 +57,9 @@ type Error struct { // String implements fmt.Stringer.String. func (e *Error) String() string { + if e == nil { + return "" + } return e.msg } @@ -1095,6 +1098,47 @@ func (a AddressWithPrefix) String() string { return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen) } +// Subnet converts the address and prefix into a Subnet value and returns it. +func (a AddressWithPrefix) Subnet() Subnet { + addrLen := len(a.Address) + if a.PrefixLen <= 0 { + return Subnet{ + address: Address(strings.Repeat("\x00", addrLen)), + mask: AddressMask(strings.Repeat("\x00", addrLen)), + } + } + if a.PrefixLen >= addrLen*8 { + return Subnet{ + address: a.Address, + mask: AddressMask(strings.Repeat("\xff", addrLen)), + } + } + + sa := make([]byte, addrLen) + sm := make([]byte, addrLen) + n := uint(a.PrefixLen) + for i := 0; i < addrLen; i++ { + if n >= 8 { + sa[i] = a.Address[i] + sm[i] = 0xff + n -= 8 + continue + } + sm[i] = ^byte(0xff >> n) + sa[i] = a.Address[i] & sm[i] + n = 0 + } + + // For extra caution, call NewSubnet rather than directly creating the Subnet + // value. If that fails it indicates a serious bug in this code, so panic is + // in order. + s, err := NewSubnet(Address(sa), AddressMask(sm)) + if err != nil { + panic("invalid subnet: " + err.Error()) + } + return s +} + // ProtocolAddress is an address and the network protocol it is associated // with. type ProtocolAddress struct { diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index fb3a0a5ee..8c0aacffa 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -195,3 +195,34 @@ func TestStatsString(t *testing.T) { t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) } } + +func TestAddressWithPrefixSubnet(t *testing.T) { + tests := []struct { + addr Address + prefixLen int + subnetAddr Address + subnetMask AddressMask + }{ + {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"}, + {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"}, + {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, + {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, + } + for _, tt := range tests { + ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen} + gotSubnet := ap.Subnet() + wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) + if err != nil { + t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) + continue + } + if gotSubnet != wantSubnet { + t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet) + } + } +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 5059ca22d..d1ba8d2ce 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -532,10 +532,11 @@ func TestBindToDeviceOption(t *testing.T) { } } -// testRead sends a packet of the given test flow into the stack by injecting it -// into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { +// testReadInternal sends a packet of the given test flow into the stack by +// injecting it into the link endpoint. It then attempts to read it from the +// UDP endpoint and depending on if this was expected to succeed verifies its +// correctness. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped bool) { c.t.Helper() payload := newPayload() @@ -553,27 +554,51 @@ func testRead(c *testContext, flow testFlow) { select { case <-ch: v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") + case <-time.After(300 * time.Millisecond): + if packetShouldBeDropped { + return // expected to time out + } + c.t.Fatal("timed out waiting for data") } } + if err != nil { + c.t.Fatal("Read failed:", err) + } + + if packetShouldBeDropped { + c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr) + } + // Check the peer address. h := flow.header4Tuple(incoming) if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr) + c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) } // Check the payload. if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) + c.t.Fatalf("bad payload: got %x, want %x", v, payload) } } +// testRead sends a packet of the given test flow into the stack by injecting it +// into the link endpoint. It then reads it from the UDP endpoint and verifies +// its correctness. +func testRead(c *testContext, flow testFlow) { + c.t.Helper() + testReadInternal(c, flow, false /* packetShouldBeDropped */) +} + +// testFailingRead sends a packet of the given test flow into the stack by +// injecting it into the link endpoint. It then tries to read it from the UDP +// endpoint and expects this to fail. +func testFailingRead(c *testContext, flow testFlow) { + c.t.Helper() + testReadInternal(c, flow, true /* packetShouldBeDropped */) +} + func TestBindEphemeralPort(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -743,13 +768,17 @@ func TestReadOnBoundToMulticast(t *testing.T) { c.t.Fatal("SetSockOpt failed:", err) } + // Check that we receive multicast packets but not unicast or broadcast + // ones. testRead(c, flow) + testFailingRead(c, broadcast) + testFailingRead(c, unicastV4) }) } } // TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast -// address and receive broadcast data on it. +// address and can receive only broadcast data. func TestV4ReadOnBoundToBroadcast(t *testing.T) { for _, flow := range []testFlow{broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -764,8 +793,31 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { c.t.Fatalf("Bind failed: %s", err) } - // Test acceptance. + // Check that we receive broadcast packets but not unicast ones. + testRead(c, flow) + testFailingRead(c, unicastV4) + }) + } +} + +// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY +// and receive broadcast and unicast data. +func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { + for _, flow := range []testFlow{broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s (", err) + } + + // Check that we receive both broadcast and unicast packets. testRead(c, flow) + testRead(c, unicastV4) }) } } diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index c85ae30dc..8b8993d3d 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -42,6 +42,26 @@ TestAddress V4EmptyAddress() { return t; } +constexpr char kMulticastAddress[] = "224.0.2.1"; + +TestAddress V4Multicast() { + TestAddress t("V4Multicast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast(&t.addr)->sin_addr.s_addr = + inet_addr(kMulticastAddress); + return t; +} + +TestAddress V4Broadcast() { + TestAddress t("V4Broadcast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast(&t.addr)->sin_addr.s_addr = + htonl(INADDR_BROADCAST); + return t; +} + void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() { got_if_infos_ = false; @@ -116,7 +136,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SetUDPBroadcast) { // Verifies that a broadcast UDP packet will arrive at all UDP sockets with // the destination port number. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - UDPBroadcastReceivedOnAllExpectedEndpoints) { + UDPBroadcastReceivedOnExpectedPort) { auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -136,51 +156,134 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, sizeof(kSockOptOn)), SyscallSucceedsWithValue(0)); - sockaddr_in rcv_addr = {}; - socklen_t rcv_addr_sz = sizeof(rcv_addr); - rcv_addr.sin_family = AF_INET; - rcv_addr.sin_addr.s_addr = htonl(INADDR_ANY); - ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast(&rcv_addr), - rcv_addr_sz), + // Bind the first socket to the ANY address and let the system assign a port. + auto rcv1_addr = V4Any(); + ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast(&rcv1_addr.addr), + rcv1_addr.addr_len), SyscallSucceedsWithValue(0)); // Retrieve port number from first socket so that it can be bound to the // second socket. - rcv_addr = {}; + socklen_t rcv_addr_sz = rcv1_addr.addr_len; ASSERT_THAT( - getsockname(rcvr1->get(), reinterpret_cast(&rcv_addr), + getsockname(rcvr1->get(), reinterpret_cast(&rcv1_addr.addr), &rcv_addr_sz), SyscallSucceedsWithValue(0)); - ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast(&rcv_addr), + EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len); + auto port = reinterpret_cast(&rcv1_addr.addr)->sin_port; + + // Bind the second socket to the same address:port as the first. + ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast(&rcv1_addr.addr), rcv_addr_sz), SyscallSucceedsWithValue(0)); // Bind the non-receiving socket to an ephemeral port. - sockaddr_in norcv_addr = {}; - norcv_addr.sin_family = AF_INET; - norcv_addr.sin_addr.s_addr = htonl(INADDR_ANY); + auto norecv_addr = V4Any(); + ASSERT_THAT(bind(norcv->get(), reinterpret_cast(&norecv_addr.addr), + norecv_addr.addr_len), + SyscallSucceedsWithValue(0)); + + // Broadcast a test message. + auto dst_addr = V4Broadcast(); + reinterpret_cast(&dst_addr.addr)->sin_port = port; + constexpr char kTestMsg[] = "hello, world"; + EXPECT_THAT( + sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + reinterpret_cast(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + + // Verify that the receiving sockets received the test message. + char buf[sizeof(kTestMsg)] = {}; + EXPECT_THAT(recv(rcvr1->get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); + memset(buf, 0, sizeof(buf)); + EXPECT_THAT(recv(rcvr2->get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); + + // Verify that the non-receiving socket did not receive the test message. + memset(buf, 0, sizeof(buf)); + EXPECT_THAT(RetryEINTR(recv)(norcv->get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// Verifies that a broadcast UDP packet will arrive at all UDP sockets bound to +// the destination port number and either INADDR_ANY or INADDR_BROADCAST, but +// not a unicast address. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + UDPBroadcastReceivedOnExpectedAddresses) { + // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its + // IPv4 address on eth0. + SKIP_IF(!got_if_infos_); + + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto norcv = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Enable SO_BROADCAST on the sending socket. + ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + // Enable SO_REUSEPORT on all sockets so that they may all be bound to the + // broadcast messages destination port. + ASSERT_THAT(setsockopt(rcvr1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(setsockopt(rcvr2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(setsockopt(norcv->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + // Bind the first socket the ANY address and let the system assign a port. + auto rcv1_addr = V4Any(); + ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast(&rcv1_addr.addr), + rcv1_addr.addr_len), + SyscallSucceedsWithValue(0)); + // Retrieve port number from first socket so that it can be bound to the + // second socket. + socklen_t rcv_addr_sz = rcv1_addr.addr_len; ASSERT_THAT( - bind(norcv->get(), reinterpret_cast(&norcv_addr), - sizeof(norcv_addr)), + getsockname(rcvr1->get(), reinterpret_cast(&rcv1_addr.addr), + &rcv_addr_sz), SyscallSucceedsWithValue(0)); + EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len); + auto port = reinterpret_cast(&rcv1_addr.addr)->sin_port; + + // Bind the second socket to the broadcast address. + auto rcv2_addr = V4Broadcast(); + reinterpret_cast(&rcv2_addr.addr)->sin_port = port; + ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast(&rcv2_addr.addr), + rcv2_addr.addr_len), + SyscallSucceedsWithValue(0)); + + // Bind the non-receiving socket to the unicast ethernet address. + auto norecv_addr = rcv1_addr; + reinterpret_cast(&norecv_addr.addr)->sin_addr = + eth_if_sin_addr_; + ASSERT_THAT(bind(norcv->get(), reinterpret_cast(&norecv_addr.addr), + norecv_addr.addr_len), + SyscallSucceedsWithValue(0)); // Broadcast a test message. - sockaddr_in dst_addr = {}; - dst_addr.sin_family = AF_INET; - dst_addr.sin_addr.s_addr = htonl(INADDR_BROADCAST); - dst_addr.sin_port = rcv_addr.sin_port; + auto dst_addr = V4Broadcast(); + reinterpret_cast(&dst_addr.addr)->sin_port = port; constexpr char kTestMsg[] = "hello, world"; EXPECT_THAT( sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast(&dst_addr), sizeof(dst_addr)), + reinterpret_cast(&dst_addr.addr), dst_addr.addr_len), SyscallSucceedsWithValue(sizeof(kTestMsg))); // Verify that the receiving sockets received the test message. char buf[sizeof(kTestMsg)] = {}; - EXPECT_THAT(read(rcvr1->get(), buf, sizeof(buf)), + EXPECT_THAT(recv(rcvr1->get(), buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(kTestMsg))); EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); memset(buf, 0, sizeof(buf)); - EXPECT_THAT(read(rcvr2->get(), buf, sizeof(buf)), + EXPECT_THAT(recv(rcvr2->get(), buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(kTestMsg))); EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); @@ -190,10 +293,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SyscallFailsWithErrno(EAGAIN)); } -// Verifies that a UDP broadcast sent via the loopback interface is not received -// by the sender. +// Verifies that a UDP broadcast can be sent and then received back on the same +// socket that is bound to the broadcast address (255.255.255.255). +// FIXME(b/141938460): This can be combined with the next test +// (UDPBroadcastSendRecvOnSocketBoundToAny). TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - UDPBroadcastViaLoopbackFails) { + UDPBroadcastSendRecvOnSocketBoundToBroadcast) { auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Enable SO_BROADCAST. @@ -201,33 +306,73 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, sizeof(kSockOptOn)), SyscallSucceedsWithValue(0)); - // Bind the sender to the loopback interface. - sockaddr_in src = {}; - socklen_t src_sz = sizeof(src); - src.sin_family = AF_INET; - src.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT( - bind(sender->get(), reinterpret_cast(&src), src_sz), - SyscallSucceedsWithValue(0)); + // Bind the sender to the broadcast address. + auto src_addr = V4Broadcast(); + ASSERT_THAT(bind(sender->get(), reinterpret_cast(&src_addr.addr), + src_addr.addr_len), + SyscallSucceedsWithValue(0)); + socklen_t src_sz = src_addr.addr_len; ASSERT_THAT(getsockname(sender->get(), - reinterpret_cast(&src), &src_sz), + reinterpret_cast(&src_addr.addr), &src_sz), SyscallSucceedsWithValue(0)); - ASSERT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(src_sz, src_addr.addr_len); // Send the message. - sockaddr_in dst = {}; - dst.sin_family = AF_INET; - dst.sin_addr.s_addr = htonl(INADDR_BROADCAST); - dst.sin_port = src.sin_port; + auto dst_addr = V4Broadcast(); + reinterpret_cast(&dst_addr.addr)->sin_port = + reinterpret_cast(&src_addr.addr)->sin_port; constexpr char kTestMsg[] = "hello, world"; - EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast(&dst), sizeof(dst)), + EXPECT_THAT( + sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + reinterpret_cast(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + + // Verify that the message was received. + char buf[sizeof(kTestMsg)] = {}; + EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); +} + +// Verifies that a UDP broadcast can be sent and then received back on the same +// socket that is bound to the ANY address (0.0.0.0). +// FIXME(b/141938460): This can be combined with the previous test +// (UDPBroadcastSendRecvOnSocketBoundToBroadcast). +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + UDPBroadcastSendRecvOnSocketBoundToAny) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - // Verify that the message was not received by the sender (loopback). + // Enable SO_BROADCAST. + ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + // Bind the sender to the ANY address. + auto src_addr = V4Any(); + ASSERT_THAT(bind(sender->get(), reinterpret_cast(&src_addr.addr), + src_addr.addr_len), + SyscallSucceedsWithValue(0)); + socklen_t src_sz = src_addr.addr_len; + ASSERT_THAT(getsockname(sender->get(), + reinterpret_cast(&src_addr.addr), &src_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(src_sz, src_addr.addr_len); + + // Send the message. + auto dst_addr = V4Broadcast(); + reinterpret_cast(&dst_addr.addr)->sin_port = + reinterpret_cast(&src_addr.addr)->sin_port; + constexpr char kTestMsg[] = "hello, world"; + EXPECT_THAT( + sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + reinterpret_cast(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + + // Verify that the message was received. char buf[sizeof(kTestMsg)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); } // Verifies that a UDP broadcast fails to send on a socket with SO_BROADCAST @@ -237,15 +382,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) { // Broadcast a test message without having enabled SO_BROADCAST on the sending // socket. - sockaddr_in addr = {}; - socklen_t addr_sz = sizeof(addr); - addr.sin_family = AF_INET; - addr.sin_port = htons(12345); - addr.sin_addr.s_addr = htonl(INADDR_BROADCAST); + auto addr = V4Broadcast(); + reinterpret_cast(&addr.addr)->sin_port = htons(12345); constexpr char kTestMsg[] = "hello, world"; EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast(&addr), addr_sz), + reinterpret_cast(&addr.addr), addr.addr_len), SyscallFailsWithErrno(EACCES)); } @@ -274,21 +416,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendUnicastOnUnbound) { reinterpret_cast(&addr), addr_sz), SyscallSucceedsWithValue(sizeof(kTestMsg))); char buf[sizeof(kTestMsg)] = {}; - ASSERT_THAT(read(rcvr->get(), buf, sizeof(buf)), + ASSERT_THAT(recv(rcvr->get(), buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(kTestMsg))); } -constexpr char kMulticastAddress[] = "224.0.2.1"; - -TestAddress V4Multicast() { - TestAddress t("V4Multicast"); - t.addr.ss_family = AF_INET; - t.addr_len = sizeof(sockaddr_in); - reinterpret_cast(&t.addr)->sin_addr.s_addr = - inet_addr(kMulticastAddress); - return t; -} - // Check that multicast packets won't be delivered to the sending socket with no // set interface or group membership. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, @@ -609,8 +740,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, } // Check that two sockets can join the same multicast group at the same time, -// and both will receive data on it. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { +// and both will receive data on it when bound to the ANY address. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + TestSendMulticastToTwoBoundToAny) { auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); std::unique_ptr receivers[2] = { ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), @@ -624,8 +756,72 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - // Bind the receiver to the v4 any address to ensure that we can receive the - // multicast packet. + // Bind to ANY to receive multicast packets. + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + EXPECT_EQ( + htonl(INADDR_ANY), + reinterpret_cast(&receiver_addr.addr)->sin_addr.s_addr); + // On the first iteration, save the port we are bound to. On the second + // iteration, verify the port is the same as the one from the first + // iteration. In other words, both sockets listen on the same port. + if (bound_port == 0) { + bound_port = + reinterpret_cast(&receiver_addr.addr)->sin_port; + } else { + EXPECT_EQ(bound_port, + reinterpret_cast(&receiver_addr.addr)->sin_port); + } + + // Register to receive multicast packets. + ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + } + + // Send a multicast packet to the group and verify both receivers get it. + auto send_addr = V4Multicast(); + reinterpret_cast(&send_addr.addr)->sin_port = bound_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + for (auto& receiver : receivers) { + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT( + RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + } +} + +// Check that two sockets can join the same multicast group at the same time, +// and both will receive data on it when bound to the multicast address. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + TestSendMulticastToTwoBoundToMulticastAddress) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + std::unique_ptr receivers[2] = { + ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket())}; + + ip_mreq group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + auto receiver_addr = V4Multicast(); + int bound_port = 0; + for (auto& receiver : receivers) { + ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); ASSERT_THAT( bind(receiver->get(), reinterpret_cast(&receiver_addr.addr), receiver_addr.addr_len), @@ -636,6 +832,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + EXPECT_EQ( + inet_addr(kMulticastAddress), + reinterpret_cast(&receiver_addr.addr)->sin_addr.s_addr); // On the first iteration, save the port we are bound to. On the second // iteration, verify the port is the same as the one from the first // iteration. In other words, both sockets listen on the same port. @@ -643,6 +842,83 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { bound_port = reinterpret_cast(&receiver_addr.addr)->sin_port; } else { + EXPECT_EQ( + inet_addr(kMulticastAddress), + reinterpret_cast(&receiver_addr.addr)->sin_addr.s_addr); + EXPECT_EQ(bound_port, + reinterpret_cast(&receiver_addr.addr)->sin_port); + } + + // Register to receive multicast packets. + ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + } + + // Send a multicast packet to the group and verify both receivers get it. + auto send_addr = V4Multicast(); + reinterpret_cast(&send_addr.addr)->sin_port = bound_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + for (auto& receiver : receivers) { + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT( + RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + } +} + +// Check that two sockets can join the same multicast group at the same time, +// and with one bound to the wildcard address and the other bound to the +// multicast address, both will receive data. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + TestSendMulticastToTwoBoundToAnyAndMulticastAddress) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + std::unique_ptr receivers[2] = { + ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket())}; + + ip_mreq group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + // The first receiver binds to the wildcard address. + auto receiver_addr = V4Any(); + int bound_port = 0; + for (auto& receiver : receivers) { + ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + // On the first iteration, save the port we are bound to and change the + // receiver address from V4Any to V4Multicast so the second receiver binds + // to that. On the second iteration, verify the port is the same as the one + // from the first iteration but the address is different. + if (bound_port == 0) { + EXPECT_EQ( + htonl(INADDR_ANY), + reinterpret_cast(&receiver_addr.addr)->sin_addr.s_addr); + bound_port = + reinterpret_cast(&receiver_addr.addr)->sin_port; + receiver_addr = V4Multicast(); + reinterpret_cast(&receiver_addr.addr)->sin_port = + bound_port; + } else { + EXPECT_EQ( + inet_addr(kMulticastAddress), + reinterpret_cast(&receiver_addr.addr)->sin_addr.s_addr); EXPECT_EQ(bound_port, reinterpret_cast(&receiver_addr.addr)->sin_port); } -- cgit v1.2.3