diff options
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 116 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 66 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/loopback_test.go | 261 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/multicast_broadcast_test.go | 148 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/route_test.go | 132 | ||||
-rw-r--r-- | pkg/tcpip/tests/utils/BUILD | 3 | ||||
-rw-r--r-- | pkg/tcpip/tests/utils/utils.go | 60 |
8 files changed, 560 insertions, 228 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 8a2140ebe..74270f5f1 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -593,7 +593,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { - ep.handlePacket(pkt) + ep.handleValidatedPacket(h, pkt) return nil } @@ -634,12 +634,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if !e.protocol.parse(pkt) { + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { stats.MalformedPacketsReceived.Increment() return } if !e.nic.IsLoopback() { + if e.protocol.options.DropExternalLoopbackTraffic { + if header.IsV4LoopbackAddress(h.SourceAddress()) { + stats.InvalidSourceAddressesReceived.Increment() + return + } + + if header.IsV4LoopbackAddress(h.DestinationAddress()) { + stats.InvalidDestinationAddressesReceived.Increment() + return + } + } + if e.protocol.stack.HandleLocal() { addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) if addressEndpoint != nil { @@ -662,62 +675,32 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handlePacket(pkt) + e.handleValidatedPacket(h, pkt) } +// handleLocalPacket is like HandlePacket except it does not perform the +// prerouting iptables hook or check for loopback traffic that originated from +// outside of the netstack (i.e. martian loopback packets). func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { stats := e.stats.ip - stats.PacketsReceived.Increment() pkt = pkt.CloneToInbound() - if e.protocol.parse(pkt) { - pkt.RXTransportChecksumValidated = canSkipRXChecksum - e.handlePacket(pkt) + pkt.RXTransportChecksumValidated = canSkipRXChecksum + + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { + stats.MalformedPacketsReceived.Increment() return } - stats.MalformedPacketsReceived.Increment() + e.handleValidatedPacket(h, pkt) } -// handlePacket is like HandlePacket except it does not perform the prerouting -// iptables hook. -func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats - h := header.IPv4(pkt.NetworkHeader().View()) - if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.ip.MalformedPacketsReceived.Increment() - return - } - - // There has been some confusion regarding verifying checksums. We need - // just look for negative 0 (0xffff) as the checksum, as it's not possible to - // get positive 0 (0) for the checksum. Some bad implementations could get it - // when doing entry replacement in the early days of the Internet, - // however the lore that one needs to check for both persists. - // - // RFC 1624 section 1 describes the source of this confusion as: - // [the partial recalculation method described in RFC 1071] computes a - // result for certain cases that differs from the one obtained from - // scratch (one's complement of one's complement sum of the original - // fields). - // - // However RFC 1624 section 5 clarifies that if using the verification method - // "recommended by RFC 1071, it does not matter if an intermediate system - // generated a -0 instead of +0". - // - // RFC1071 page 1 specifies the verification method as: - // (3) To check a checksum, the 1's complement sum is computed over the - // same set of octets, including the checksum field. If the result - // is all 1 bits (-0 in 1's complement arithmetic), the check - // succeeds. - if h.CalculateChecksum() != 0xffff { - stats.ip.MalformedPacketsReceived.Increment() - return - } - srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -1114,13 +1097,46 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// parse is like Parse but also attempts to parse the transport layer. +// parseAndValidate parses the packet (including its transport layer header) and +// returns the parsed IP header. // -// Returns true if the network header was successfully parsed. -func (p *protocol) parse(pkt *stack.PacketBuffer) bool { +// Returns true if the IP header was successfully parsed. +func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) { transProtoNum, hasTransportHdr, ok := p.Parse(pkt) if !ok { - return false + return nil, false + } + + h := header.IPv4(pkt.NetworkHeader().View()) + // Do not include the link header's size when calculating the size of the IP + // packet. + if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) { + return nil, false + } + + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. + // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + if h.CalculateChecksum() != 0xffff { + return nil, false } if hasTransportHdr { @@ -1134,7 +1150,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool { } } - return true + return h, true } // Parse implements stack.NetworkProtocol.Parse. @@ -1213,6 +1229,10 @@ func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolN type Options struct { // IGMP holds options for IGMP. IGMP IGMPOptions + + // DropExternalLoopbackTraffic indicates that inbound loopback packets (i.e. + // martian loopback packets) should be dropped. + DropExternalLoopbackTraffic bool } // NewProtocolWithOptions returns an IPv4 network protocol. diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 350493958..b94cb428f 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -884,9 +884,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { - ep.handlePacket(pkt) + ep.handleValidatedPacket(h, pkt) return nil } @@ -925,12 +924,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if !e.protocol.parse(pkt) { + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { stats.MalformedPacketsReceived.Increment() return } if !e.nic.IsLoopback() { + if e.protocol.options.DropExternalLoopbackTraffic { + if header.IsV6LoopbackAddress(h.SourceAddress()) { + stats.InvalidSourceAddressesReceived.Increment() + return + } + + if header.IsV6LoopbackAddress(h.DestinationAddress()) { + stats.InvalidDestinationAddressesReceived.Increment() + return + } + } + if e.protocol.stack.HandleLocal() { addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) if addressEndpoint != nil { @@ -953,35 +965,31 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handlePacket(pkt) + e.handleValidatedPacket(h, pkt) } +// handleLocalPacket is like HandlePacket except it does not perform the +// prerouting iptables hook or check for loopback traffic that originated from +// outside of the netstack (i.e. martian loopback packets). func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { stats := e.stats.ip - stats.PacketsReceived.Increment() pkt = pkt.CloneToInbound() - if e.protocol.parse(pkt) { - pkt.RXTransportChecksumValidated = canSkipRXChecksum - e.handlePacket(pkt) + pkt.RXTransportChecksumValidated = canSkipRXChecksum + + h, ok := e.protocol.parseAndValidate(pkt) + if !ok { + stats.MalformedPacketsReceived.Increment() return } - stats.MalformedPacketsReceived.Increment() + e.handleValidatedPacket(h, pkt) } -// handlePacket is like HandlePacket except it does not perform the prerouting -// iptables hook. -func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats.ip - - h := header.IPv6(pkt.NetworkHeader().View()) - if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.MalformedPacketsReceived.Increment() - return - } srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -1920,13 +1928,21 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// parse is like Parse but also attempts to parse the transport layer. +// parseAndValidate parses the packet (including its transport layer header) and +// returns the parsed IP header. // -// Returns true if the network header was successfully parsed. -func (p *protocol) parse(pkt *stack.PacketBuffer) bool { +// Returns true if the IP header was successfully parsed. +func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) { transProtoNum, hasTransportHdr, ok := p.Parse(pkt) if !ok { - return false + return nil, false + } + + h := header.IPv6(pkt.NetworkHeader().View()) + // Do not include the link header's size when calculating the size of the IP + // packet. + if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) { + return nil, false } if hasTransportHdr { @@ -1940,7 +1956,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool { } } - return true + return h, true } // Parse implements stack.NetworkProtocol.Parse. @@ -2054,6 +2070,10 @@ type Options struct { // DADConfigs holds the default DAD configurations used by IPv6 endpoints. DADConfigs stack.DADConfigurations + + // DropExternalLoopbackTraffic indicates that inbound loopback packets (i.e. + // martian loopback packets) should be dropped. + DropExternalLoopbackTraffic bool } // NewProtocolWithOptions returns an IPv6 network protocol. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 58aabe547..3cc8c36f1 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -72,11 +72,13 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 80afc2825..0a9ea1aa8 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -24,11 +24,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -502,3 +504,262 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { }) } } + +func TestExternalLoopbackTraffic(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") + + numPackets = 1 + ) + + loopbackSourcedICMPv4 := func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address) + } + + loopbackSourcedICMPv6 := func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address) + } + + loopbackDestinedICMPv4 := func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback) + } + + loopbackDestinedICMPv6 := func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback) + } + + invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { + return s.InvalidSourceAddressesReceived + } + + invalidDestAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { + return s.InvalidDestinationAddressesReceived + } + + tests := []struct { + name string + dropExternalLoopback bool + forwarding bool + rxICMP func(*channel.Endpoint) + invalidAddressStat func(tcpip.IPStats) *tcpip.StatCounter + shouldAccept bool + }{ + { + name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: false, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: false, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: true, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: true, + rxICMP: loopbackSourcedICMPv4, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback destined traffic without forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: false, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback destined traffic without forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: false, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv4 external loopback destined traffic with forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: true, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: true, + }, + { + name: "IPv4 external loopback destined traffic with forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: true, + rxICMP: loopbackDestinedICMPv4, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + + { + name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: false, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: false, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: true, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: true, + }, + { + name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: true, + rxICMP: loopbackSourcedICMPv6, + invalidAddressStat: invalidSrcAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback destined traffic without forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: false, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback destined traffic without forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: false, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + { + name: "IPv6 external loopback destined traffic with forwarding and drop external loopback disabled", + dropExternalLoopback: false, + forwarding: true, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: true, + }, + { + name: "IPv6 external loopback destined traffic with forwarding and drop external loopback enabled", + dropExternalLoopback: true, + forwarding: true, + rxICMP: loopbackDestinedICMPv6, + invalidAddressStat: invalidDestAddrStat, + shouldAccept: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + DropExternalLoopbackTraffic: test.dropExternalLoopback, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + DropExternalLoopbackTraffic: test.dropExternalLoopback, + }), + }, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err) + } + if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err) + } + + if err := s.CreateNIC(nicID2, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err) + } + if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err) + } + + if test.forwarding { + if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + } + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }, + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID1, + }, + tcpip.Route{ + Destination: ipv4Loopback.WithPrefix().Subnet(), + NIC: nicID2, + }, + tcpip.Route{ + Destination: header.IPv6Loopback.WithPrefix().Subnet(), + NIC: nicID2, + }, + }) + + stats := s.Stats().IP + invalidAddressStat := test.invalidAddressStat(stats) + deliveredPacketsStat := stats.PacketsDelivered + if got := invalidAddressStat.Value(); got != 0 { + t.Fatalf("got invalidAddressStat.Value() = %d, want = 0", got) + } + if got := deliveredPacketsStat.Value(); got != 0 { + t.Fatalf("got deliveredPacketsStat.Value() = %d, want = 0", got) + } + test.rxICMP(e) + var expectedInvalidPackets uint64 + if !test.shouldAccept { + expectedInvalidPackets = numPackets + } + if got := invalidAddressStat.Value(); got != expectedInvalidPackets { + t.Fatalf("got invalidAddressStat.Value() = %d, want = %d", got, expectedInvalidPackets) + } + if got, want := deliveredPacketsStat.Value(), numPackets-expectedInvalidPackets; got != want { + t.Fatalf("got deliveredPacketsStat.Value() = %d, want = %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 29266a4fc..77f4a88ec 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -45,82 +45,61 @@ const ( func TestPingMulticastBroadcast(t *testing.T) { const nicID = 1 - rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(icmp.ProtocolNumber4), - TTL: ttl, - SrcAddr: utils.RemoteIPv4Addr, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) { - totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize - hdr := buffer.NewPrependable(totalLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) - pkt.SetCode(0) - pkt.SetChecksum(0) - pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: pkt, - Src: utils.RemoteIPv6Addr, - Dst: dst, - })) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: ttl, - SrcAddr: utils.RemoteIPv6Addr, - DstAddr: dst, - }) - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - tests := []struct { - name string - dstAddr tcpip.Address + name string + protoNum tcpip.NetworkProtocolNumber + rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address) + srcAddr tcpip.Address + dstAddr tcpip.Address + expectedSrc tcpip.Address }{ { - name: "IPv4 unicast", - dstAddr: utils.Ipv4Addr.Address, + name: "IPv4 unicast", + protoNum: header.IPv4ProtocolNumber, + dstAddr: utils.Ipv4Addr.Address, + srcAddr: utils.RemoteIPv4Addr, + rxICMP: utils.RxICMPv4EchoRequest, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv4 directed broadcast", - dstAddr: utils.Ipv4SubnetBcast, + name: "IPv4 directed broadcast", + protoNum: header.IPv4ProtocolNumber, + rxICMP: utils.RxICMPv4EchoRequest, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4SubnetBcast, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv4 broadcast", - dstAddr: header.IPv4Broadcast, + name: "IPv4 broadcast", + protoNum: header.IPv4ProtocolNumber, + rxICMP: utils.RxICMPv4EchoRequest, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: header.IPv4Broadcast, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv4 all-systems multicast", - dstAddr: header.IPv4AllSystems, + name: "IPv4 all-systems multicast", + protoNum: header.IPv4ProtocolNumber, + rxICMP: utils.RxICMPv4EchoRequest, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: header.IPv4AllSystems, + expectedSrc: utils.Ipv4Addr.Address, }, { - name: "IPv6 unicast", - dstAddr: utils.Ipv6Addr.Address, + name: "IPv6 unicast", + protoNum: header.IPv6ProtocolNumber, + rxICMP: utils.RxICMPv6EchoRequest, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr.Address, + expectedSrc: utils.Ipv6Addr.Address, }, { - name: "IPv6 all-nodes multicast", - dstAddr: header.IPv6AllNodesMulticastAddress, + name: "IPv6 all-nodes multicast", + protoNum: header.IPv6ProtocolNumber, + rxICMP: utils.RxICMPv6EchoRequest, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: header.IPv6AllNodesMulticastAddress, + expectedSrc: utils.Ipv6Addr.Address, }, } @@ -157,44 +136,29 @@ func TestPingMulticastBroadcast(t *testing.T) { }, }) - var rxICMP func(*channel.Endpoint, tcpip.Address) - var expectedSrc tcpip.Address - var expectedDst tcpip.Address - var protoNum tcpip.NetworkProtocolNumber - switch l := len(test.dstAddr); l { - case header.IPv4AddressSize: - rxICMP = rxIPv4ICMP - expectedSrc = utils.Ipv4Addr.Address - expectedDst = utils.RemoteIPv4Addr - protoNum = header.IPv4ProtocolNumber - case header.IPv6AddressSize: - rxICMP = rxIPv6ICMP - expectedSrc = utils.Ipv6Addr.Address - expectedDst = utils.RemoteIPv6Addr - protoNum = header.IPv6ProtocolNumber - default: - t.Fatalf("got unexpected address length = %d bytes", l) - } - - rxICMP(e, test.dstAddr) + test.rxICMP(e, test.srcAddr, test.dstAddr) pkt, ok := e.Read() if !ok { t.Fatal("expected ICMP response") } - if pkt.Route.LocalAddress != expectedSrc { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc) + if pkt.Route.LocalAddress != test.expectedSrc { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.expectedSrc) } - if pkt.Route.RemoteAddress != expectedDst { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) + // The destination of the response packet should be the source of the + // original packet. + if pkt.Route.RemoteAddress != test.srcAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.srcAddr) } - src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if src != expectedSrc { - t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) + src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if src != test.expectedSrc { + t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc) } - if dst != expectedDst { - t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst) + // The destination of the response packet should be the source of the + // original packet. + if dst != test.srcAddr { + t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr) } }) } diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 4455f6dd7..568a982bb 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -16,6 +16,7 @@ package route_test import ( "bytes" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -161,78 +162,79 @@ func TestLocalPing(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - HandleLocal: true, - }) - e := test.linkEndpoint() - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } + for _, dropExternalLoopback := range []bool{true, false} { + t.Run(fmt.Sprintf("DropExternalLoopback=%t", dropExternalLoopback), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + DropExternalLoopbackTraffic: dropExternalLoopback, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + DropExternalLoopbackTraffic: dropExternalLoopback, + }), + }, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + HandleLocal: true, + }) + e := test.linkEndpoint() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } - if len(test.localAddr) != 0 { - if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) - } - } + if len(test.localAddr) != 0 { + if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + } + } - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) - } - defer ep.Close() - - connAddr := tcpip.FullAddress{Addr: test.localAddr} - { - err := ep.Connect(connAddr) - if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" { - t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff) - } - } + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) + } + defer ep.Close() - if test.expectedConnectErr != nil { - return - } + connAddr := tcpip.FullAddress{Addr: test.localAddr} + if err := ep.Connect(connAddr); err != test.expectedConnectErr { + t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + } - payload := test.icmpBuf(t) - var r bytes.Reader - r.Reset(payload) - var wOpts tcpip.WriteOptions - if n, err := ep.Write(&r, wOpts); err != nil { - t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) - } else if n != int64(len(payload)) { - t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload)) - } + if test.expectedConnectErr != nil { + return + } - // Wait for the endpoint to become readable. - <-ch + var r bytes.Reader + payload := test.icmpBuf(t) + r.Reset(payload) + var wOpts tcpip.WriteOptions + if n, err := ep.Write(&r, wOpts); err != nil { + t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) + } else if n != int64(len(payload)) { + t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) + } - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - res, err := ep.Read(&buf, opts) - if err != nil { - t.Fatalf("ep.Read(_, %#v): %s", opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: test.localAddr}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } + // Wait for the endpoint to become readable. + <-ch - test.checkLinkEndpoint(t, e) + var w bytes.Buffer + rr, err := ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if err != nil { + t.Fatalf("ep.Read(...): %s", err) + } + if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if rr.RemoteAddr.Addr != test.localAddr { + t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr) + } + + test.checkLinkEndpoint(t, e) + }) + } }) } } diff --git a/pkg/tcpip/tests/utils/BUILD b/pkg/tcpip/tests/utils/BUILD index 433004148..a9699a367 100644 --- a/pkg/tcpip/tests/utils/BUILD +++ b/pkg/tcpip/tests/utils/BUILD @@ -8,12 +8,15 @@ go_library( visibility = ["//pkg/tcpip/tests:__subpackages__"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", "//pkg/tcpip/link/nested", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", ], ) diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index f414a2234..d1c9f3a94 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -20,13 +20,16 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) // Common NIC IDs used by tests. @@ -45,6 +48,10 @@ const ( LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") ) +const ( + ttl = 255 +) + // Common IP addresses used by tests. var ( Ipv4Addr = tcpip.AddressWithPrefix{ @@ -312,3 +319,56 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. }, }) } + +// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on +// the provided endpoint. +func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(header.ICMPv4UnusedCode) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ttl, + SrcAddr: src, + DstAddr: dst, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} + +// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on +// the provided endpoint. +func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(header.ICMPv6UnusedCode) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: src, + Dst: dst, + })) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ttl, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} |