From 4065604e1b6b767754a8ce939add6fdd91616a24 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 17 Mar 2021 11:10:04 -0700 Subject: Drop loopback traffic from outside of the stack Loopback traffic should be stack-local but gVisor has some clients that depend on the ability to receive loopback traffic that originated from outside of the stack. Because of this, we guard this change behind IP protocol options. Test: integration_test.TestExternalLoopbackTraffic PiperOrigin-RevId: 363461242 --- pkg/tcpip/network/ipv4/ipv4.go | 116 +++++---- pkg/tcpip/network/ipv6/ipv6.go | 66 ++++-- pkg/tcpip/tests/integration/BUILD | 2 + pkg/tcpip/tests/integration/loopback_test.go | 261 +++++++++++++++++++++ .../tests/integration/multicast_broadcast_test.go | 148 +++++------- pkg/tcpip/tests/integration/route_test.go | 132 ++++++----- pkg/tcpip/tests/utils/BUILD | 3 + pkg/tcpip/tests/utils/utils.go | 60 +++++ 8 files changed, 560 insertions(+), 228 deletions(-) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 8a2140ebe..74270f5f1 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -593,7 +593,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { - ep.handlePacket(pkt) + ep.handleValidatedPacket(h, pkt) return nil } @@ -634,12 +634,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if !e.protocol.parse(pkt) { + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { stats.MalformedPacketsReceived.Increment() return } if !e.nic.IsLoopback() { + if e.protocol.options.DropExternalLoopbackTraffic { + if header.IsV4LoopbackAddress(h.SourceAddress()) { + stats.InvalidSourceAddressesReceived.Increment() + return + } + + if header.IsV4LoopbackAddress(h.DestinationAddress()) { + stats.InvalidDestinationAddressesReceived.Increment() + return + } + } + if e.protocol.stack.HandleLocal() { addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) if addressEndpoint != nil { @@ -662,62 +675,32 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handlePacket(pkt) + e.handleValidatedPacket(h, pkt) } +// handleLocalPacket is like HandlePacket except it does not perform the +// prerouting iptables hook or check for loopback traffic that originated from +// outside of the netstack (i.e. martian loopback packets). func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { stats := e.stats.ip - stats.PacketsReceived.Increment() pkt = pkt.CloneToInbound() - if e.protocol.parse(pkt) { - pkt.RXTransportChecksumValidated = canSkipRXChecksum - e.handlePacket(pkt) + pkt.RXTransportChecksumValidated = canSkipRXChecksum + + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { + stats.MalformedPacketsReceived.Increment() return } - stats.MalformedPacketsReceived.Increment() + e.handleValidatedPacket(h, pkt) } -// handlePacket is like HandlePacket except it does not perform the prerouting -// iptables hook. -func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats - h := header.IPv4(pkt.NetworkHeader().View()) - if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.ip.MalformedPacketsReceived.Increment() - return - } - - // There has been some confusion regarding verifying checksums. We need - // just look for negative 0 (0xffff) as the checksum, as it's not possible to - // get positive 0 (0) for the checksum. Some bad implementations could get it - // when doing entry replacement in the early days of the Internet, - // however the lore that one needs to check for both persists. - // - // RFC 1624 section 1 describes the source of this confusion as: - // [the partial recalculation method described in RFC 1071] computes a - // result for certain cases that differs from the one obtained from - // scratch (one's complement of one's complement sum of the original - // fields). - // - // However RFC 1624 section 5 clarifies that if using the verification method - // "recommended by RFC 1071, it does not matter if an intermediate system - // generated a -0 instead of +0". - // - // RFC1071 page 1 specifies the verification method as: - // (3) To check a checksum, the 1's complement sum is computed over the - // same set of octets, including the checksum field. If the result - // is all 1 bits (-0 in 1's complement arithmetic), the check - // succeeds. - if h.CalculateChecksum() != 0xffff { - stats.ip.MalformedPacketsReceived.Increment() - return - } - srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -1114,13 +1097,46 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// parse is like Parse but also attempts to parse the transport layer. +// parseAndValidate parses the packet (including its transport layer header) and +// returns the parsed IP header. // -// Returns true if the network header was successfully parsed. -func (p *protocol) parse(pkt *stack.PacketBuffer) bool { +// Returns true if the IP header was successfully parsed. +func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) { transProtoNum, hasTransportHdr, ok := p.Parse(pkt) if !ok { - return false + return nil, false + } + + h := header.IPv4(pkt.NetworkHeader().View()) + // Do not include the link header's size when calculating the size of the IP + // packet. + if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) { + return nil, false + } + + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. + // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + if h.CalculateChecksum() != 0xffff { + return nil, false } if hasTransportHdr { @@ -1134,7 +1150,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool { } } - return true + return h, true } // Parse implements stack.NetworkProtocol.Parse. @@ -1213,6 +1229,10 @@ func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolN type Options struct { // IGMP holds options for IGMP. IGMP IGMPOptions + + // DropExternalLoopbackTraffic indicates that inbound loopback packets (i.e. + // martian loopback packets) should be dropped. + DropExternalLoopbackTraffic bool } // NewProtocolWithOptions returns an IPv4 network protocol. diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 350493958..b94cb428f 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -884,9 +884,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { - ep.handlePacket(pkt) + ep.handleValidatedPacket(h, pkt) return nil } @@ -925,12 +924,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if !e.protocol.parse(pkt) { + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { stats.MalformedPacketsReceived.Increment() return } if !e.nic.IsLoopback() { + if e.protocol.options.DropExternalLoopbackTraffic { + if header.IsV6LoopbackAddress(h.SourceAddress()) { + stats.InvalidSourceAddressesReceived.Increment() + return + } + + if header.IsV6LoopbackAddress(h.DestinationAddress()) { + stats.InvalidDestinationAddressesReceived.Increment() + return + } + } + if e.protocol.stack.HandleLocal() { addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) if addressEndpoint != nil { @@ -953,35 +965,31 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handlePacket(pkt) + e.handleValidatedPacket(h, pkt) } +// handleLocalPacket is like HandlePacket except it does not perform the +// prerouting iptables hook or check for loopback traffic that originated from +// outside of the netstack (i.e. martian loopback packets). func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { stats := e.stats.ip - stats.PacketsReceived.Increment() pkt = pkt.CloneToInbound() - if e.protocol.parse(pkt) { - pkt.RXTransportChecksumValidated = canSkipRXChecksum - e.handlePacket(pkt) + pkt.RXTransportChecksumValidated = canSkipRXChecksum + + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { + stats.MalformedPacketsReceived.Increment() return } - stats.MalformedPacketsReceived.Increment() + e.handleValidatedPacket(h, pkt) } -// handlePacket is like HandlePacket except it does not perform the prerouting -// iptables hook. -func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats.ip - - h := header.IPv6(pkt.NetworkHeader().View()) - if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.MalformedPacketsReceived.Increment() - return - } srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -1920,13 +1928,21 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// parse is like Parse but also attempts to parse the transport layer. +// parseAndValidate parses the packet (including its transport layer header) and +// returns the parsed IP header. // -// Returns true if the network header was successfully parsed. -func (p *protocol) parse(pkt *stack.PacketBuffer) bool { +// Returns true if the IP header was successfully parsed. +func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) { transProtoNum, hasTransportHdr, ok := p.Parse(pkt) if !ok { - return false + return nil, false + } + + h := header.IPv6(pkt.NetworkHeader().View()) + // Do not include the link header's size when calculating the size of the IP + // packet. + if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) { + return nil, false } if hasTransportHdr { @@ -1940,7 +1956,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool { } } - return true + return h, true } // Parse implements stack.NetworkProtocol.Parse. @@ -2054,6 +2070,10 @@ type Options struct { // DADConfigs holds the default DAD configurations used by IPv6 endpoints. DADConfigs stack.DADConfigurations + + // DropExternalLoopbackTraffic indicates that inbound loopback packets (i.e. + // martian loopback packets) should be dropped. + DropExternalLoopbackTraffic bool } // NewProtocolWithOptions returns an IPv6 network protocol. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 58aabe547..3cc8c36f1 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -72,11 +72,13 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 80afc2825..0a9ea1aa8 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -24,11 +24,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -502,3 +504,262 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { }) } } + +func TestExternalLoopbackTraffic(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") + + numPackets = 1 + ) + + loopbackSourcedICMPv4 := func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address) + } + + loopbackSourcedICMPv6 := func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address) + } + + loopbackDestinedICMPv4 := func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback) + } + + loopbackDestinedICMPv6 := func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback) + } + + invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { + return s.InvalidSourceAddressesReceived + } + + invalidDestAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { + return s.InvalidDestinationAddressesReceived + } + + tests := []struct { + name string + dropExternalLoopback bool + forwarding bool + rxICMP func(*channel.Endpoint) + invalidAddressStat func(tcpip.IPStats) *tcpip.StatCounter + shouldAccept bool + }{ + { + name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: false, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: false, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: true, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: true, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback destined traffic without forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: false, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback destined traffic without forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: false, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback destined traffic with forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: true, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: true, + }, + { + name: "IPv4 external loopback destined traffic with forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: true, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + + { + name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: false, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: false, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: true, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: true, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback destined traffic without forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: false, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback destined traffic without forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: false, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback destined traffic with forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: true, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: true, + }, + { + name: "IPv6 external loopback destined traffic with forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: true, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + DropExternalLoopbackTraffic: test.dropExternalLoopback, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + DropExternalLoopbackTraffic: test.dropExternalLoopback, + }), + }, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err) + } + if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err) + } + + if err := s.CreateNIC(nicID2, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err) + } + if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err) + } + + if test.forwarding { + if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + } + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }, + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID1, + }, + tcpip.Route{ + Destination: ipv4Loopback.WithPrefix().Subnet(), + NIC: nicID2, + }, + tcpip.Route{ + Destination: header.IPv6Loopback.WithPrefix().Subnet(), + NIC: nicID2, + }, + }) + + stats := s.Stats().IP + invalidAddressStat := test.invalidAddressStat(stats) + deliveredPacketsStat := stats.PacketsDelivered + if got := invalidAddressStat.Value(); got != 0 { + t.Fatalf("got invalidAddressStat.Value() = %d, want = 0", got) + } + if got := deliveredPacketsStat.Value(); got != 0 { + t.Fatalf("got deliveredPacketsStat.Value() = %d, want = 0", got) + } + test.rxICMP(e) + var expectedInvalidPackets uint64 + if !test.shouldAccept { + expectedInvalidPackets = numPackets + } + if got := invalidAddressStat.Value(); got != expectedInvalidPackets { + t.Fatalf("got invalidAddressStat.Value() = %d, want = %d", got, expectedInvalidPackets) + } + if got, want := deliveredPacketsStat.Value(), numPackets-expectedInvalidPackets; got != want { + t.Fatalf("got deliveredPacketsStat.Value() = %d, want = %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 29266a4fc..77f4a88ec 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -45,82 +45,61 @@ const ( func TestPingMulticastBroadcast(t *testing.T) { const nicID = 1 - rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ttl, - SrcAddr: utils.RemoteIPv4Addr, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) { - totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: utils.RemoteIPv6Addr, - Dst: dst, - })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: ttl, - SrcAddr: utils.RemoteIPv6Addr, - DstAddr: dst, - }) - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - tests := []struct { - name string - dstAddr tcpip.Address + name string + protoNum tcpip.NetworkProtocolNumber + rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address) + srcAddr tcpip.Address + dstAddr tcpip.Address + expectedSrc tcpip.Address }{ { - name: "IPv4 unicast", - dstAddr: utils.Ipv4Addr.Address, + name: "IPv4 unicast", + protoNum: header.IPv4ProtocolNumber, + dstAddr: utils.Ipv4Addr.Address, + srcAddr: utils.RemoteIPv4Addr, + rxICMP: utils.RxICMPv4EchoRequest, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv4 directed broadcast", - dstAddr: utils.Ipv4SubnetBcast, + name: "IPv4 directed broadcast", + protoNum: header.IPv4ProtocolNumber, + rxICMP: utils.RxICMPv4EchoRequest, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4SubnetBcast, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv4 broadcast", - dstAddr: header.IPv4Broadcast, + name: "IPv4 broadcast", + protoNum: header.IPv4ProtocolNumber, + rxICMP: utils.RxICMPv4EchoRequest, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: header.IPv4Broadcast, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv4 all-systems multicast", - dstAddr: header.IPv4AllSystems, + name: "IPv4 all-systems multicast", + protoNum: header.IPv4ProtocolNumber, + rxICMP: utils.RxICMPv4EchoRequest, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: header.IPv4AllSystems, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv6 unicast", - dstAddr: utils.Ipv6Addr.Address, + name: "IPv6 unicast", + protoNum: header.IPv6ProtocolNumber, + rxICMP: utils.RxICMPv6EchoRequest, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr.Address, + expectedSrc: utils.Ipv6Addr.Address, }, { - name: "IPv6 all-nodes multicast", - dstAddr: header.IPv6AllNodesMulticastAddress, + name: "IPv6 all-nodes multicast", + protoNum: header.IPv6ProtocolNumber, + rxICMP: utils.RxICMPv6EchoRequest, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: header.IPv6AllNodesMulticastAddress, + expectedSrc: utils.Ipv6Addr.Address, }, } @@ -157,44 +136,29 @@ func TestPingMulticastBroadcast(t *testing.T) { }, }) - var rxICMP func(*channel.Endpoint, tcpip.Address) - var expectedSrc tcpip.Address - var expectedDst tcpip.Address - var protoNum tcpip.NetworkProtocolNumber - switch l := len(test.dstAddr); l { - case header.IPv4AddressSize: - rxICMP = rxIPv4ICMP - expectedSrc = utils.Ipv4Addr.Address - expectedDst = utils.RemoteIPv4Addr - protoNum = header.IPv4ProtocolNumber - case header.IPv6AddressSize: - rxICMP = rxIPv6ICMP - expectedSrc = utils.Ipv6Addr.Address - expectedDst = utils.RemoteIPv6Addr - protoNum = header.IPv6ProtocolNumber - default: - t.Fatalf("got unexpected address length = %d bytes", l) - } - - rxICMP(e, test.dstAddr) + test.rxICMP(e, test.srcAddr, test.dstAddr) pkt, ok := e.Read() if !ok { t.Fatal("expected ICMP response") } - if pkt.Route.LocalAddress != expectedSrc { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc) + if pkt.Route.LocalAddress != test.expectedSrc { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.expectedSrc) } - if pkt.Route.RemoteAddress != expectedDst { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) + // The destination of the response packet should be the source of the + // original packet. + if pkt.Route.RemoteAddress != test.srcAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.srcAddr) } - src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if src != expectedSrc { - t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) + src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if src != test.expectedSrc { + t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc) } - if dst != expectedDst { - t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst) + // The destination of the response packet should be the source of the + // original packet. + if dst != test.srcAddr { + t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr) } }) } diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 4455f6dd7..568a982bb 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -16,6 +16,7 @@ package route_test import ( "bytes" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -161,78 +162,79 @@ func TestLocalPing(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - HandleLocal: true, - }) - e := test.linkEndpoint() - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } + for _, dropExternalLoopback := range []bool{true, false} { + t.Run(fmt.Sprintf("DropExternalLoopback=%t", dropExternalLoopback), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + DropExternalLoopbackTraffic: dropExternalLoopback, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + DropExternalLoopbackTraffic: dropExternalLoopback, + }), + }, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + HandleLocal: true, + }) + e := test.linkEndpoint() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } - if len(test.localAddr) != 0 { - if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) - } - } + if len(test.localAddr) != 0 { + if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + } + } - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) - } - defer ep.Close() - - connAddr := tcpip.FullAddress{Addr: test.localAddr} - { - err := ep.Connect(connAddr) - if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" { - t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff) - } - } + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) + } + defer ep.Close() - if test.expectedConnectErr != nil { - return - } + connAddr := tcpip.FullAddress{Addr: test.localAddr} + if err := ep.Connect(connAddr); err != test.expectedConnectErr { + t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + } - payload := test.icmpBuf(t) - var r bytes.Reader - r.Reset(payload) - var wOpts tcpip.WriteOptions - if n, err := ep.Write(&r, wOpts); err != nil { - t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) - } else if n != int64(len(payload)) { - t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload)) - } + if test.expectedConnectErr != nil { + return + } - // Wait for the endpoint to become readable. - <-ch + var r bytes.Reader + payload := test.icmpBuf(t) + r.Reset(payload) + var wOpts tcpip.WriteOptions + if n, err := ep.Write(&r, wOpts); err != nil { + t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) + } else if n != int64(len(payload)) { + t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) + } - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - res, err := ep.Read(&buf, opts) - if err != nil { - t.Fatalf("ep.Read(_, %#v): %s", opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: test.localAddr}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } + // Wait for the endpoint to become readable. + <-ch - test.checkLinkEndpoint(t, e) + var w bytes.Buffer + rr, err := ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if err != nil { + t.Fatalf("ep.Read(...): %s", err) + } + if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if rr.RemoteAddr.Addr != test.localAddr { + t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr) + } + + test.checkLinkEndpoint(t, e) + }) + } }) } } diff --git a/pkg/tcpip/tests/utils/BUILD b/pkg/tcpip/tests/utils/BUILD index 433004148..a9699a367 100644 --- a/pkg/tcpip/tests/utils/BUILD +++ b/pkg/tcpip/tests/utils/BUILD @@ -8,12 +8,15 @@ go_library( visibility = ["//pkg/tcpip/tests:__subpackages__"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", "//pkg/tcpip/link/nested", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", ], ) diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index f414a2234..d1c9f3a94 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -20,13 +20,16 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) // Common NIC IDs used by tests. @@ -45,6 +48,10 @@ const ( LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") ) +const ( + ttl = 255 +) + // Common IP addresses used by tests. var ( Ipv4Addr = tcpip.AddressWithPrefix{ @@ -312,3 +319,56 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. }, }) } + +// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on +// the provided endpoint. +func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(header.ICMPv4UnusedCode) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ttl, + SrcAddr: src, + DstAddr: dst, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} + +// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on +// the provided endpoint. +func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(header.ICMPv6UnusedCode) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: src, + Dst: dst, + })) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ttl, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} -- cgit v1.2.3