diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/multicast_broadcast_test.go | 120 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 68 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 32 |
4 files changed, 193 insertions, 45 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index b902c6ca9..0774b5382 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isMulticastOrBroadcast(id.LocalAddress) { + if isInboundMulticastOrBroadcast(r) { mpep.handlePacketAll(r, id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return @@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() @@ -546,7 +546,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a TCP packet with a non-unicast source or destination // address, then do nothing further and instruct the caller to do the same. - if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) { // TCP can only be used to communicate between a single source and a // single destination; the addresses must be unicast. r.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -677,10 +677,14 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isMulticastOrBroadcast(addr tcpip.Address) bool { - return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +func isInboundMulticastOrBroadcast(r *Route) bool { + return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) } -func isUnicast(addr tcpip.Address) bool { - return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +func isInboundUnicast(r *Route) bool { + return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r) +} + +func isOutboundUnicast(r *Route) bool { + return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress) } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 52c27e045..659acbc7a 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -436,3 +437,122 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { }) } } + +// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all +// interested endpoints. +func TestReuseAddrAndBroadcast(t *testing.T) { + const ( + nicID = 1 + localPort = 9000 + loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") + ) + + data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + + tests := []struct { + name string + broadcastAddr tcpip.Address + }{ + { + name: "Subnet directed broadcast", + broadcastAddr: loopbackBroadcast, + }, + { + name: "IPv4 broadcast", + broadcastAddr: header.IPv4Broadcast, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x7f\x00\x00\x01", + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + // We use the empty subnet instead of just the loopback subnet so we + // also have a route to the IPv4 Broadcast address. + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + // We create endpoints that bind to both the wildcard address and the + // broadcast address to make sure both of these types of "broadcast + // interested" endpoints receive broadcast packets. + wq := waiter.Queue{} + var eps []tcpip.Endpoint + for _, bindWildcard := range []bool{false, true} { + // Create multiple endpoints for each type of "broadcast interested" + // endpoint so we can test that all endpoints receive the broadcast + // packet. + for i := 0; i < 2; i++ { + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + defer ep.Close() + + if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err) + } + + bindAddr := tcpip.FullAddress{Port: localPort} + if bindWildcard { + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + } + } else { + bindAddr.Addr = test.broadcastAddr + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + } + } + + eps = append(eps, ep) + } + } + + for i, wep := range eps { + writeOpts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{ + Addr: test.broadcastAddr, + Port: localPort, + }, + } + if n, _, err := wep.Write(data, writeOpts); err != nil { + t.Fatalf("eps[%d].Write(_, _): %s", i, err) + } else if want := int64(len(data)); n != want { + t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want) + } + + for j, rep := range eps { + if gotPayload, _, err := rep.Read(nil); err != nil { + t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) + } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) + } + } + } + }) + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 0d13e1efd..b1e5f1b24 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -5214,6 +5214,8 @@ func TestListenBacklogFull(t *testing.T) { func TestListenNoAcceptNonUnicastV4(t *testing.T) { multicastAddr := tcpip.Address("\xe0\x00\x01\x02") otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + subnet := context.StackAddrWithPrefix.Subnet() + subnetBroadcastAddr := subnet.Broadcast() tests := []struct { name string @@ -5221,53 +5223,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { dstAddr tcpip.Address }{ { - "SourceUnspecified", - header.IPv4Any, - context.StackAddr, + name: "SourceUnspecified", + srcAddr: header.IPv4Any, + dstAddr: context.StackAddr, }, { - "SourceBroadcast", - header.IPv4Broadcast, - context.StackAddr, + name: "SourceBroadcast", + srcAddr: header.IPv4Broadcast, + dstAddr: context.StackAddr, }, { - "SourceOurMulticast", - multicastAddr, - context.StackAddr, + name: "SourceOurMulticast", + srcAddr: multicastAddr, + dstAddr: context.StackAddr, }, { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackAddr, + name: "SourceOtherMulticast", + srcAddr: otherMulticastAddr, + dstAddr: context.StackAddr, }, { - "DestUnspecified", - context.TestAddr, - header.IPv4Any, + name: "DestUnspecified", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Any, }, { - "DestBroadcast", - context.TestAddr, - header.IPv4Broadcast, + name: "DestBroadcast", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Broadcast, }, { - "DestOurMulticast", - context.TestAddr, - multicastAddr, + name: "DestOurMulticast", + srcAddr: context.TestAddr, + dstAddr: multicastAddr, }, { - "DestOtherMulticast", - context.TestAddr, - otherMulticastAddr, + name: "DestOtherMulticast", + srcAddr: context.TestAddr, + dstAddr: otherMulticastAddr, + }, + { + name: "SrcSubnetBroadcast", + srcAddr: subnetBroadcastAddr, + dstAddr: context.StackAddr, + }, + { + name: "DestSubnetBroadcast", + srcAddr: context.TestAddr, + dstAddr: subnetBroadcastAddr, }, } for _, test := range tests { - test := test // capture range variable - t.Run(test.name, func(t *testing.T) { - t.Parallel() - c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5367,11 +5375,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { } for _, test := range tests { - test := test // capture range variable - t.Run(test.name, func(t *testing.T) { - t.Parallel() - c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index baf7df197..85e8c1c75 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -53,11 +53,11 @@ const ( TestPort = 4096 // StackV6Addr is the IPv6 address assigned to the stack. - StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" // TestV6Addr is the source address for packets sent to the stack via // the link layer endpoint. - TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" // StackV4MappedAddr is StackAddr as a mapped v6 address. StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr @@ -73,6 +73,18 @@ const ( testInitialSequenceNumber = 789 ) +// StackAddrWithPrefix is StackAddr with its associated prefix length. +var StackAddrWithPrefix = tcpip.AddressWithPrefix{ + Address: StackAddr, + PrefixLen: 24, +} + +// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length. +var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: StackV6Addr, + PrefixLen: header.IIDOffsetInIPv6Address * 8, +} + // Headers is used to represent the TCP header fields when building a // new packet. type Headers struct { @@ -184,12 +196,20 @@ func New(t *testing.T, mtu uint32) *Context { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + v4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: StackAddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) } - if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + v6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: StackV6AddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) } s.SetRouteTable([]tcpip.Route{ |