summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-09-16 12:19:06 -0700
committergVisor bot <gvisor-bot@google.com>2020-09-16 12:20:45 -0700
commit87c5c0ad2568215675391f6fb7fe335bcae06173 (patch)
tree6ad914dab0e359ab7eed9fdfdb78abb5d852674c /pkg
parent326a1dbb73addeaf8e51af55b8a9de95638b4017 (diff)
Receive broadcast packets on interested endpoints
When a broadcast packet is received by the stack, the packet should be delivered to each endpoint that may be interested in the packet. This includes all any address and specified broadcast address listeners. Test: integration_test.TestReuseAddrAndBroadcast PiperOrigin-RevId: 332060652
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go18
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go120
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go68
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go32
4 files changed, 193 insertions, 45 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index b902c6ca9..0774b5382 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isMulticastOrBroadcast(id.LocalAddress) {
+ if isInboundMulticastOrBroadcast(r) {
mpep.handlePacketAll(r, id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
@@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
@@ -546,7 +546,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a TCP packet with a non-unicast source or destination
// address, then do nothing further and instruct the caller to do the same.
- if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) {
// TCP can only be used to communicate between a single source and a
// single destination; the addresses must be unicast.
r.Stats().TCP.InvalidSegmentsReceived.Increment()
@@ -677,10 +677,14 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isMulticastOrBroadcast(addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+func isInboundMulticastOrBroadcast(r *Route) bool {
+ return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
}
-func isUnicast(addr tcpip.Address) bool {
- return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+func isInboundUnicast(r *Route) bool {
+ return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r)
+}
+
+func isOutboundUnicast(r *Route) bool {
+ return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress)
}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 52c27e045..659acbc7a 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -436,3 +437,122 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
})
}
}
+
+// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all
+// interested endpoints.
+func TestReuseAddrAndBroadcast(t *testing.T) {
+ const (
+ nicID = 1
+ localPort = 9000
+ loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff")
+ )
+
+ data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+
+ tests := []struct {
+ name string
+ broadcastAddr tcpip.Address
+ }{
+ {
+ name: "Subnet directed broadcast",
+ broadcastAddr: loopbackBroadcast,
+ },
+ {
+ name: "IPv4 broadcast",
+ broadcastAddr: header.IPv4Broadcast,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x7f\x00\x00\x01",
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ // We use the empty subnet instead of just the loopback subnet so we
+ // also have a route to the IPv4 Broadcast address.
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // We create endpoints that bind to both the wildcard address and the
+ // broadcast address to make sure both of these types of "broadcast
+ // interested" endpoints receive broadcast packets.
+ wq := waiter.Queue{}
+ var eps []tcpip.Endpoint
+ for _, bindWildcard := range []bool{false, true} {
+ // Create multiple endpoints for each type of "broadcast interested"
+ // endpoint so we can test that all endpoints receive the broadcast
+ // packet.
+ for i := 0; i < 2; i++ {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err)
+ }
+
+ bindAddr := tcpip.FullAddress{Port: localPort}
+ if bindWildcard {
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ }
+ } else {
+ bindAddr.Addr = test.broadcastAddr
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ }
+ }
+
+ eps = append(eps, ep)
+ }
+ }
+
+ for i, wep := range eps {
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{
+ Addr: test.broadcastAddr,
+ Port: localPort,
+ },
+ }
+ if n, _, err := wep.Write(data, writeOpts); err != nil {
+ t.Fatalf("eps[%d].Write(_, _): %s", i, err)
+ } else if want := int64(len(data)); n != want {
+ t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want)
+ }
+
+ for j, rep := range eps {
+ if gotPayload, _, err := rep.Read(nil); err != nil {
+ t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err)
+ } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 0d13e1efd..b1e5f1b24 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5214,6 +5214,8 @@ func TestListenBacklogFull(t *testing.T) {
func TestListenNoAcceptNonUnicastV4(t *testing.T) {
multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+ subnet := context.StackAddrWithPrefix.Subnet()
+ subnetBroadcastAddr := subnet.Broadcast()
tests := []struct {
name string
@@ -5221,53 +5223,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
dstAddr tcpip.Address
}{
{
- "SourceUnspecified",
- header.IPv4Any,
- context.StackAddr,
+ name: "SourceUnspecified",
+ srcAddr: header.IPv4Any,
+ dstAddr: context.StackAddr,
},
{
- "SourceBroadcast",
- header.IPv4Broadcast,
- context.StackAddr,
+ name: "SourceBroadcast",
+ srcAddr: header.IPv4Broadcast,
+ dstAddr: context.StackAddr,
},
{
- "SourceOurMulticast",
- multicastAddr,
- context.StackAddr,
+ name: "SourceOurMulticast",
+ srcAddr: multicastAddr,
+ dstAddr: context.StackAddr,
},
{
- "SourceOtherMulticast",
- otherMulticastAddr,
- context.StackAddr,
+ name: "SourceOtherMulticast",
+ srcAddr: otherMulticastAddr,
+ dstAddr: context.StackAddr,
},
{
- "DestUnspecified",
- context.TestAddr,
- header.IPv4Any,
+ name: "DestUnspecified",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Any,
},
{
- "DestBroadcast",
- context.TestAddr,
- header.IPv4Broadcast,
+ name: "DestBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Broadcast,
},
{
- "DestOurMulticast",
- context.TestAddr,
- multicastAddr,
+ name: "DestOurMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: multicastAddr,
},
{
- "DestOtherMulticast",
- context.TestAddr,
- otherMulticastAddr,
+ name: "DestOtherMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: otherMulticastAddr,
+ },
+ {
+ name: "SrcSubnetBroadcast",
+ srcAddr: subnetBroadcastAddr,
+ dstAddr: context.StackAddr,
+ },
+ {
+ name: "DestSubnetBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: subnetBroadcastAddr,
},
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5367,11 +5375,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index baf7df197..85e8c1c75 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -53,11 +53,11 @@ const (
TestPort = 4096
// StackV6Addr is the IPv6 address assigned to the stack.
- StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
// TestV6Addr is the source address for packets sent to the stack via
// the link layer endpoint.
- TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
// StackV4MappedAddr is StackAddr as a mapped v6 address.
StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
@@ -73,6 +73,18 @@ const (
testInitialSequenceNumber = 789
)
+// StackAddrWithPrefix is StackAddr with its associated prefix length.
+var StackAddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackAddr,
+ PrefixLen: 24,
+}
+
+// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length.
+var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackV6Addr,
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+}
+
// Headers is used to represent the TCP header fields when building a
// new packet.
type Headers struct {
@@ -184,12 +196,20 @@ func New(t *testing.T, mtu uint32) *Context {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: StackAddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
}
- if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: StackV6AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{