diff options
author | Chris Kuiper <ckuiper@google.com> | 2019-10-03 19:30:01 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-10-03 19:31:35 -0700 |
commit | 48745251611b5c152b1a2b66a0f2f30dd4dc1ed9 (patch) | |
tree | 38f4113e257a3b97d0d1da5ff49d25ce015f56c9 /pkg/tcpip/stack | |
parent | 135aadb5179c94972504910e66e9e2540c09d489 (diff) |
Implement proper local broadcast behavior
The behavior for sending and receiving local broadcast (255.255.255.255)
traffic is as follows:
Outgoing
--------
* A broadcast packet sent on a socket that is bound to an interface goes out
that interface
* A broadcast packet sent on an unbound socket follows the route table to
select the outgoing interface
+ if an explicit route entry exists for 255.255.255.255/32, use that one
+ else use the default route
* Broadcast packets are looped back and delivered following the rules for
incoming packets (see next). This is the same behavior as for multicast
packets, except that it cannot be disabled via sockopt.
Incoming
--------
* Sockets wishing to receive broadcast packets must bind to either INADDR_ANY
(0.0.0.0) or INADDR_BROADCAST (255.255.255.255). No other socket receives
broadcast packets.
* Broadcast packets are multiplexed to all sockets matching it. This is the
same behavior as for multicast packets.
* A socket can bind to 255.255.255.255:<port> and then receive its own
broadcast packets sent to 255.255.255.255:<port>
In addition, this change implicitly fixes an issue with multicast reception. If
two sockets want to receive a given multicast stream and one is bound to ANY
while the other is bound to the multicast address, only one of them will
receive the traffic.
PiperOrigin-RevId: 272792377
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 27 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 123 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 47 |
5 files changed, 154 insertions, 58 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f6106f762..5993fe582 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -104,6 +104,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback func (n *NIC) enable() *tcpip.Error { n.attachLinkEndpoint() + // Create an endpoint to receive broadcast packets on this interface. + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + if err := n.AddAddress(tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, + }, NeverPrimaryEndpoint); err != nil { + return err + } + } + // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives // and responds to the various NDP messages that are destined to the @@ -372,7 +382,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { - // Don't include expired or tempory endpoints to avoid confusion and + // Don't include expired or temporary endpoints to avoid confusion and // prevent the caller from using those. switch ref.getKind() { case permanentExpired, temporary: @@ -624,21 +634,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.stack.AddLinkAddress(n.id, src, remote) - // If the packet is destined to the IPv4 Broadcast address, then make a - // route to each IPv4 network endpoint and let each endpoint handle the - // packet. - if dst == header.IPv4Broadcast { - // n.endpoints is mutex protected so acquire lock. - n.mu.RLock() - for _, ref := range n.endpoints { - if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) - } - } - n.mu.RUnlock() - return - } - if ref := n.getRef(protocol, dst); ref != nil { handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) return diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 5c8b7977a..0b09e6517 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -59,6 +59,8 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop = PacketLoop } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop } return Route{ @@ -208,10 +210,17 @@ func (r *Route) Clone() Route { return *r } -// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast. +// MakeLoopedRoute duplicates the given route with special handling for routes +// used for sending multicast or broadcast packets. In those cases the +// multicast/broadcast address is the remote address when sending out, but for +// incoming (looped) packets it becomes the local address. Similarly, the local +// interface address that was the local address going out becomes the remote +// address coming in. This is different to unicast routes where local and +// remote addresses remain the same as they identify location (local vs remote) +// not direction (source vs destination). func (r *Route) MakeLoopedRoute() Route { l := r.Clone() - if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { + if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress l.RemoteLinkAddress = l.LocalLinkAddress } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 90c2cf1be..ff574a055 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -902,7 +902,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } else { for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) { + if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } if nic, ok := s.nics[route.NIC]; ok { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d2dede8a9..a2e0a6e7b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -952,10 +952,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet works. testSendTo(t, s, dstAddr, ep, nil) @@ -967,10 +967,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != localAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. testSend(t, r, ep, nil) @@ -1016,17 +1016,33 @@ func TestSpoofingNoAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } -func TestBroadcastNeedsNoRoute(t *testing.T) { +func verifyRoute(gotRoute, wantRoute stack.Route) error { + if gotRoute.LocalAddress != wantRoute.LocalAddress { + return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) + } + if gotRoute.RemoteAddress != wantRoute.RemoteAddress { + return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) + } + if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { + return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) + } + if gotRoute.NextHop != wantRoute.NextHop { + return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) + } + return nil +} + +func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) @@ -1039,28 +1055,99 @@ func TestBroadcastNeedsNoRoute(t *testing.T) { // If there is no endpoint, it won't work. if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } - if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %s", fakeNetNumber, header.IPv4Any, err) + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + if err := s.AddProtocolAddress(1, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if r.LocalAddress != header.IPv4Any { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, header.IPv4Any) + // If the NIC doesn't exist, it won't work. + if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { + t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } +} + +func TestOutgoingBroadcastWithRouteTable(t *testing.T) { + defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. + nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") + // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. + nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") - if r.RemoteAddress != header.IPv4Broadcast { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, header.IPv4Broadcast) + // Create a new stack with two NICs. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + if err := s.CreateNIC(2, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) } - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { + t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) + } + + // Set the initial route table. + rt := []tcpip.Route{ + {Destination: nic1Addr.Subnet(), NIC: 1}, + {Destination: nic2Addr.Subnet(), NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1}, + } + s.SetRouteTable(rt) + + // When an interface is given, the route for a broadcast goes through it. + r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + + // When an interface is not given, it consults the route table. + // 1. Case: Using the default route. + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) + } + + // 2. Case: Having an explicit route for broadcast will select that one. + rt = append( + []tcpip.Route{ + {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + }, + rt..., + ) + s.SetRouteTable(rt) + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 8c768c299..92267ce4d 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -63,7 +63,7 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { + if isMulticastOrBroadcast(id.LocalAddress) { mpep.handlePacketAll(r, id, vv) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return @@ -338,23 +338,14 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - // If a sender bound to the Loopback interface sends a broadcast, - // that broadcast must not be delivered to the sender. - if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort { - return false - } - - // If the packet is a broadcast, then find all matching transport endpoints. - // Otherwise, try to find a single matching transport endpoint. - destEps := make([]*endpointsByNic, 0, 1) eps.mu.RLock() - if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast { - for epID, endpoint := range eps.endpoints { - if epID.LocalPort == id.LocalPort { - destEps = append(destEps, endpoint) - } - } + // Determine which transport endpoint or endpoints to deliver this packet to. + // If the packet is a broadcast or multicast, then find all matching + // transport endpoints. + var destEps []*endpointsByNic + if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + destEps = d.findAllEndpointsLocked(eps, vv, id) } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { destEps = append(destEps, ep) } @@ -426,10 +417,11 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return true } -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic { + var matchedEPs []*endpointsByNic // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the local address. @@ -437,7 +429,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the remote part. @@ -445,15 +437,24 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.RemoteAddress = "" nid.RemotePort = 0 if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with only the local port. nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } + return matchedEPs +} + +// findEndpointLocked returns the endpoint that most closely matches the given +// id. +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { + if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 { + return matchedEPs[0] + } return nil } @@ -491,3 +492,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN } } } + +func isMulticastOrBroadcast(addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +} |