diff options
author | gVisor bot <gvisor-bot@google.com> | 2020-07-09 22:33:53 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-07-09 22:35:42 -0700 |
commit | 5df3a8fedef7e54550d4c6b7172e25216600ee9f (patch) | |
tree | aa2cf2c00315018049f4f09b0e966678b3804337 /pkg | |
parent | 5946f111827fa4e342a2e6e9c043c198d2e5cb03 (diff) |
Discard multicast UDP source address.
RFC-1122 (and others) specify that UDP should not receive
datagrams that have a source address that is a multicast address.
Packets should never be received FROM a multicast address.
See also, RFC 768: 'User Datagram Protocol'
J. Postel, ISI, 28 August 1980
A UDP datagram received with an invalid IP source address
(e.g., a broadcast or multicast address) must be discarded
by UDP or by the IP layer (see rfc 1122 Section 3.2.1.3).
This CL does not address TCP or broadcast which is more complicated.
Also adds a test for both ipv6 and ipv4 UDP.
Fixes #3154
PiperOrigin-RevId: 320547674
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 104 |
4 files changed, 103 insertions, 18 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 3b248a953..5a3cedd7c 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -192,6 +192,7 @@ var Metrics = tcpip.Stats{ PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."), + InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."), }, } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 25534a10d..cf7291d09 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -782,7 +782,7 @@ type CongestionControlOption string // control algorithms. type AvailableCongestionControlOption string -// buffer moderation. +// ModerateReceiveBufferOption is used by buffer moderation. type ModerateReceiveBufferOption bool // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the @@ -1244,6 +1244,9 @@ type UDPStats struct { // ChecksumErrors is the number of datagrams dropped due to bad checksums. ChecksumErrors *StatCounter + + // InvalidSourceAddress is the number of invalid sourced datagrams dropped. + InvalidSourceAddress *StatCounter } // Stats holds statistics about the networking stack. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 0584ec8dc..4e9e114a9 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1377,6 +1377,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } + // Never receive from a multicast address. + if header.IsV4MulticastAddress(id.RemoteAddress) || + header.IsV6MulticastAddress(id.RemoteAddress) { + e.stack.Stats().UDP.InvalidSourceAddress.Increment() + e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + return + } + // Verify checksum unless RX checksum offload is enabled. // On IPv4, UDP checksum is optional, and a zero value means // the transmitter omitted the checksum generation (RFC768). @@ -1395,10 +1404,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } } - e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed { e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 91ba031fa..90781cf49 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -83,16 +83,18 @@ type header4Tuple struct { type testFlow int const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket + reverseMulticast4 // V4 multicast src. Must fail. + reverseMulticast6 // V6 multicast src. Must fail. ) func (flow testFlow) String() string { @@ -117,6 +119,10 @@ func (flow testFlow) String() string { return "broadcast" case broadcastIn6: return "broadcastIn6" + case reverseMulticast4: + return "reverseMulticast4" + case reverseMulticast6: + return "reverseMulticast6" default: return "unknown" } @@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { h.dstAddr.Addr = multicastV6Addr } } + if flow.isReverseMulticast() { + h.srcAddr.Addr = flow.getMcastAddr() + } return h } @@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { // endpoint for this flow. func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: + case unicastV4, multicastV4, broadcast, reverseMulticast4: return ipv4.ProtocolNumber default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool { switch flow { case unicastV6Only, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool { switch flow { case multicastV4, multicastV4in6, multicastV6, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool { switch flow { case broadcast, broadcastIn6: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool { switch flow { case unicastV4in6, multicastV4in6, broadcastIn6: return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) } } +func (flow testFlow) isReverseMulticast() bool { + switch flow { + case reverseMulticast4, reverseMulticast6: + return true + default: + return false + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -872,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { } } +// TestReadFromMulticast checks that an endpoint will NOT receive a packet +// that was sent with multicast SOURCE address. +func TestReadFromMulticast(t *testing.T) { + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + testFailingRead(c, flow, false /* expectReadError */) + }) + } +} + +// TestReadFromMulticaststats checks that a discarded packet +// that that was sent with multicast SOURCE address increments +// the correct counters and that a regular packet does not. +func TestReadFromMulticastStats(t *testing.T) { + t.Helper() + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + c.injectPacket(flow, payload) + + var want uint64 = 0 + if flow.isReverseMulticast() { + want = 1 + } + if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want { + t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) + } + if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want { + t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) + } + }) + } +} + // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY // and receive broadcast and unicast data. func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { |