summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go116
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go66
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go261
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go148
-rw-r--r--pkg/tcpip/tests/integration/route_test.go132
-rw-r--r--pkg/tcpip/tests/utils/BUILD3
-rw-r--r--pkg/tcpip/tests/utils/utils.go60
8 files changed, 560 insertions, 228 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 8a2140ebe..74270f5f1 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -593,7 +593,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
- ep.handlePacket(pkt)
+ ep.handleValidatedPacket(h, pkt)
return nil
}
@@ -634,12 +634,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- if !e.protocol.parse(pkt) {
+ h, ok := e.protocol.parseAndValidate(pkt)
+ if !ok {
stats.MalformedPacketsReceived.Increment()
return
}
if !e.nic.IsLoopback() {
+ if e.protocol.options.DropExternalLoopbackTraffic {
+ if header.IsV4LoopbackAddress(h.SourceAddress()) {
+ stats.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
+ if header.IsV4LoopbackAddress(h.DestinationAddress()) {
+ stats.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+ }
+
if e.protocol.stack.HandleLocal() {
addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
if addressEndpoint != nil {
@@ -662,62 +675,32 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handlePacket(pkt)
+ e.handleValidatedPacket(h, pkt)
}
+// handleLocalPacket is like HandlePacket except it does not perform the
+// prerouting iptables hook or check for loopback traffic that originated from
+// outside of the netstack (i.e. martian loopback packets).
func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) {
stats := e.stats.ip
-
stats.PacketsReceived.Increment()
pkt = pkt.CloneToInbound()
- if e.protocol.parse(pkt) {
- pkt.RXTransportChecksumValidated = canSkipRXChecksum
- e.handlePacket(pkt)
+ pkt.RXTransportChecksumValidated = canSkipRXChecksum
+
+ h, ok := e.protocol.parseAndValidate(pkt)
+ if !ok {
+ stats.MalformedPacketsReceived.Increment()
return
}
- stats.MalformedPacketsReceived.Increment()
+ e.handleValidatedPacket(h, pkt)
}
-// handlePacket is like HandlePacket except it does not perform the prerouting
-// iptables hook.
-func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.stats
- h := header.IPv4(pkt.NetworkHeader().View())
- if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- stats.ip.MalformedPacketsReceived.Increment()
- return
- }
-
- // There has been some confusion regarding verifying checksums. We need
- // just look for negative 0 (0xffff) as the checksum, as it's not possible to
- // get positive 0 (0) for the checksum. Some bad implementations could get it
- // when doing entry replacement in the early days of the Internet,
- // however the lore that one needs to check for both persists.
- //
- // RFC 1624 section 1 describes the source of this confusion as:
- // [the partial recalculation method described in RFC 1071] computes a
- // result for certain cases that differs from the one obtained from
- // scratch (one's complement of one's complement sum of the original
- // fields).
- //
- // However RFC 1624 section 5 clarifies that if using the verification method
- // "recommended by RFC 1071, it does not matter if an intermediate system
- // generated a -0 instead of +0".
- //
- // RFC1071 page 1 specifies the verification method as:
- // (3) To check a checksum, the 1's complement sum is computed over the
- // same set of octets, including the checksum field. If the result
- // is all 1 bits (-0 in 1's complement arithmetic), the check
- // succeeds.
- if h.CalculateChecksum() != 0xffff {
- stats.ip.MalformedPacketsReceived.Increment()
- return
- }
-
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
@@ -1114,13 +1097,46 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// parse is like Parse but also attempts to parse the transport layer.
+// parseAndValidate parses the packet (including its transport layer header) and
+// returns the parsed IP header.
//
-// Returns true if the network header was successfully parsed.
-func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
+// Returns true if the IP header was successfully parsed.
+func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) {
transProtoNum, hasTransportHdr, ok := p.Parse(pkt)
if !ok {
- return false
+ return nil, false
+ }
+
+ h := header.IPv4(pkt.NetworkHeader().View())
+ // Do not include the link header's size when calculating the size of the IP
+ // packet.
+ if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) {
+ return nil, false
+ }
+
+ // There has been some confusion regarding verifying checksums. We need
+ // just look for negative 0 (0xffff) as the checksum, as it's not possible to
+ // get positive 0 (0) for the checksum. Some bad implementations could get it
+ // when doing entry replacement in the early days of the Internet,
+ // however the lore that one needs to check for both persists.
+ //
+ // RFC 1624 section 1 describes the source of this confusion as:
+ // [the partial recalculation method described in RFC 1071] computes a
+ // result for certain cases that differs from the one obtained from
+ // scratch (one's complement of one's complement sum of the original
+ // fields).
+ //
+ // However RFC 1624 section 5 clarifies that if using the verification method
+ // "recommended by RFC 1071, it does not matter if an intermediate system
+ // generated a -0 instead of +0".
+ //
+ // RFC1071 page 1 specifies the verification method as:
+ // (3) To check a checksum, the 1's complement sum is computed over the
+ // same set of octets, including the checksum field. If the result
+ // is all 1 bits (-0 in 1's complement arithmetic), the check
+ // succeeds.
+ if h.CalculateChecksum() != 0xffff {
+ return nil, false
}
if hasTransportHdr {
@@ -1134,7 +1150,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
}
}
- return true
+ return h, true
}
// Parse implements stack.NetworkProtocol.Parse.
@@ -1213,6 +1229,10 @@ func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolN
type Options struct {
// IGMP holds options for IGMP.
IGMP IGMPOptions
+
+ // DropExternalLoopbackTraffic indicates that inbound loopback packets (i.e.
+ // martian loopback packets) should be dropped.
+ DropExternalLoopbackTraffic bool
}
// NewProtocolWithOptions returns an IPv4 network protocol.
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 350493958..b94cb428f 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -884,9 +884,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
dstAddr := h.DestinationAddress()
// Check if the destination is owned by the stack.
-
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
- ep.handlePacket(pkt)
+ ep.handleValidatedPacket(h, pkt)
return nil
}
@@ -925,12 +924,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- if !e.protocol.parse(pkt) {
+ h, ok := e.protocol.parseAndValidate(pkt)
+ if !ok {
stats.MalformedPacketsReceived.Increment()
return
}
if !e.nic.IsLoopback() {
+ if e.protocol.options.DropExternalLoopbackTraffic {
+ if header.IsV6LoopbackAddress(h.SourceAddress()) {
+ stats.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
+ if header.IsV6LoopbackAddress(h.DestinationAddress()) {
+ stats.InvalidDestinationAddressesReceived.Increment()
+ return
+ }
+ }
+
if e.protocol.stack.HandleLocal() {
addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
if addressEndpoint != nil {
@@ -953,35 +965,31 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handlePacket(pkt)
+ e.handleValidatedPacket(h, pkt)
}
+// handleLocalPacket is like HandlePacket except it does not perform the
+// prerouting iptables hook or check for loopback traffic that originated from
+// outside of the netstack (i.e. martian loopback packets).
func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) {
stats := e.stats.ip
-
stats.PacketsReceived.Increment()
pkt = pkt.CloneToInbound()
- if e.protocol.parse(pkt) {
- pkt.RXTransportChecksumValidated = canSkipRXChecksum
- e.handlePacket(pkt)
+ pkt.RXTransportChecksumValidated = canSkipRXChecksum
+
+ h, ok := e.protocol.parseAndValidate(pkt)
+ if !ok {
+ stats.MalformedPacketsReceived.Increment()
return
}
- stats.MalformedPacketsReceived.Increment()
+ e.handleValidatedPacket(h, pkt)
}
-// handlePacket is like HandlePacket except it does not perform the prerouting
-// iptables hook.
-func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.stats.ip
-
- h := header.IPv6(pkt.NetworkHeader().View())
- if !h.IsValid(pkt.Data().Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- stats.MalformedPacketsReceived.Increment()
- return
- }
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
@@ -1920,13 +1928,21 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// parse is like Parse but also attempts to parse the transport layer.
+// parseAndValidate parses the packet (including its transport layer header) and
+// returns the parsed IP header.
//
-// Returns true if the network header was successfully parsed.
-func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
+// Returns true if the IP header was successfully parsed.
+func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) {
transProtoNum, hasTransportHdr, ok := p.Parse(pkt)
if !ok {
- return false
+ return nil, false
+ }
+
+ h := header.IPv6(pkt.NetworkHeader().View())
+ // Do not include the link header's size when calculating the size of the IP
+ // packet.
+ if !h.IsValid(pkt.Size() - pkt.LinkHeader().View().Size()) {
+ return nil, false
}
if hasTransportHdr {
@@ -1940,7 +1956,7 @@ func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
}
}
- return true
+ return h, true
}
// Parse implements stack.NetworkProtocol.Parse.
@@ -2054,6 +2070,10 @@ type Options struct {
// DADConfigs holds the default DAD configurations used by IPv6 endpoints.
DADConfigs stack.DADConfigurations
+
+ // DropExternalLoopbackTraffic indicates that inbound loopback packets (i.e.
+ // martian loopback packets) should be dropped.
+ DropExternalLoopbackTraffic bool
}
// NewProtocolWithOptions returns an IPv6 network protocol.
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 58aabe547..3cc8c36f1 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -72,11 +72,13 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 80afc2825..0a9ea1aa8 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -24,11 +24,13 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -502,3 +504,262 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
})
}
}
+
+func TestExternalLoopbackTraffic(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01")
+
+ numPackets = 1
+ )
+
+ loopbackSourcedICMPv4 := func(e *channel.Endpoint) {
+ utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address)
+ }
+
+ loopbackSourcedICMPv6 := func(e *channel.Endpoint) {
+ utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address)
+ }
+
+ loopbackDestinedICMPv4 := func(e *channel.Endpoint) {
+ utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback)
+ }
+
+ loopbackDestinedICMPv6 := func(e *channel.Endpoint) {
+ utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback)
+ }
+
+ invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
+ return s.InvalidSourceAddressesReceived
+ }
+
+ invalidDestAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
+ return s.InvalidDestinationAddressesReceived
+ }
+
+ tests := []struct {
+ name string
+ dropExternalLoopback bool
+ forwarding bool
+ rxICMP func(*channel.Endpoint)
+ invalidAddressStat func(tcpip.IPStats) *tcpip.StatCounter
+ shouldAccept bool
+ }{
+ {
+ name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback disabled",
+ dropExternalLoopback: false,
+ forwarding: false,
+ rxICMP: loopbackSourcedICMPv4,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback enabled",
+ dropExternalLoopback: true,
+ forwarding: false,
+ rxICMP: loopbackSourcedICMPv4,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback disabled",
+ dropExternalLoopback: false,
+ forwarding: true,
+ rxICMP: loopbackSourcedICMPv4,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback enabled",
+ dropExternalLoopback: true,
+ forwarding: true,
+ rxICMP: loopbackSourcedICMPv4,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv4 external loopback destined traffic without forwarding and drop external loopback disabled",
+ dropExternalLoopback: false,
+ forwarding: false,
+ rxICMP: loopbackDestinedICMPv4,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv4 external loopback destined traffic without forwarding and drop external loopback enabled",
+ dropExternalLoopback: true,
+ forwarding: false,
+ rxICMP: loopbackDestinedICMPv4,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv4 external loopback destined traffic with forwarding and drop external loopback disabled",
+ dropExternalLoopback: false,
+ forwarding: true,
+ rxICMP: loopbackDestinedICMPv4,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv4 external loopback destined traffic with forwarding and drop external loopback enabled",
+ dropExternalLoopback: true,
+ forwarding: true,
+ rxICMP: loopbackDestinedICMPv4,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+
+ {
+ name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback disabled",
+ dropExternalLoopback: false,
+ forwarding: false,
+ rxICMP: loopbackSourcedICMPv6,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback enabled",
+ dropExternalLoopback: true,
+ forwarding: false,
+ rxICMP: loopbackSourcedICMPv6,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback disabled",
+ dropExternalLoopback: false,
+ forwarding: true,
+ rxICMP: loopbackSourcedICMPv6,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback enabled",
+ dropExternalLoopback: true,
+ forwarding: true,
+ rxICMP: loopbackSourcedICMPv6,
+ invalidAddressStat: invalidSrcAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv6 external loopback destined traffic without forwarding and drop external loopback disabled",
+ dropExternalLoopback: false,
+ forwarding: false,
+ rxICMP: loopbackDestinedICMPv6,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv6 external loopback destined traffic without forwarding and drop external loopback enabled",
+ dropExternalLoopback: true,
+ forwarding: false,
+ rxICMP: loopbackDestinedICMPv6,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ {
+ name: "IPv6 external loopback destined traffic with forwarding and drop external loopback disabled",
+ dropExternalLoopback: false,
+ forwarding: true,
+ rxICMP: loopbackDestinedICMPv6,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: true,
+ },
+ {
+ name: "IPv6 external loopback destined traffic with forwarding and drop external loopback enabled",
+ dropExternalLoopback: true,
+ forwarding: true,
+ rxICMP: loopbackDestinedICMPv6,
+ invalidAddressStat: invalidDestAddrStat,
+ shouldAccept: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ DropExternalLoopbackTraffic: test.dropExternalLoopback,
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ DropExternalLoopbackTraffic: test.dropExternalLoopback,
+ }),
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err)
+ }
+ if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err)
+ }
+
+ if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err)
+ }
+ if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err)
+ }
+
+ if test.forwarding {
+ if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID1,
+ },
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID1,
+ },
+ tcpip.Route{
+ Destination: ipv4Loopback.WithPrefix().Subnet(),
+ NIC: nicID2,
+ },
+ tcpip.Route{
+ Destination: header.IPv6Loopback.WithPrefix().Subnet(),
+ NIC: nicID2,
+ },
+ })
+
+ stats := s.Stats().IP
+ invalidAddressStat := test.invalidAddressStat(stats)
+ deliveredPacketsStat := stats.PacketsDelivered
+ if got := invalidAddressStat.Value(); got != 0 {
+ t.Fatalf("got invalidAddressStat.Value() = %d, want = 0", got)
+ }
+ if got := deliveredPacketsStat.Value(); got != 0 {
+ t.Fatalf("got deliveredPacketsStat.Value() = %d, want = 0", got)
+ }
+ test.rxICMP(e)
+ var expectedInvalidPackets uint64
+ if !test.shouldAccept {
+ expectedInvalidPackets = numPackets
+ }
+ if got := invalidAddressStat.Value(); got != expectedInvalidPackets {
+ t.Fatalf("got invalidAddressStat.Value() = %d, want = %d", got, expectedInvalidPackets)
+ }
+ if got, want := deliveredPacketsStat.Value(), numPackets-expectedInvalidPackets; got != want {
+ t.Fatalf("got deliveredPacketsStat.Value() = %d, want = %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 29266a4fc..77f4a88ec 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -45,82 +45,61 @@ const (
func TestPingMulticastBroadcast(t *testing.T) {
const nicID = 1
- rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) {
- totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4Echo)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(^header.Checksum(pkt, 0))
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(icmp.ProtocolNumber4),
- TTL: ttl,
- SrcAddr: utils.RemoteIPv4Addr,
- DstAddr: dst,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) {
- totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: utils.RemoteIPv6Addr,
- Dst: dst,
- }))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: ttl,
- SrcAddr: utils.RemoteIPv6Addr,
- DstAddr: dst,
- })
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
tests := []struct {
- name string
- dstAddr tcpip.Address
+ name string
+ protoNum tcpip.NetworkProtocolNumber
+ rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectedSrc tcpip.Address
}{
{
- name: "IPv4 unicast",
- dstAddr: utils.Ipv4Addr.Address,
+ name: "IPv4 unicast",
+ protoNum: header.IPv4ProtocolNumber,
+ dstAddr: utils.Ipv4Addr.Address,
+ srcAddr: utils.RemoteIPv4Addr,
+ rxICMP: utils.RxICMPv4EchoRequest,
+ expectedSrc: utils.Ipv4Addr.Address,
},
{
- name: "IPv4 directed broadcast",
- dstAddr: utils.Ipv4SubnetBcast,
+ name: "IPv4 directed broadcast",
+ protoNum: header.IPv4ProtocolNumber,
+ rxICMP: utils.RxICMPv4EchoRequest,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4SubnetBcast,
+ expectedSrc: utils.Ipv4Addr.Address,
},
{
- name: "IPv4 broadcast",
- dstAddr: header.IPv4Broadcast,
+ name: "IPv4 broadcast",
+ protoNum: header.IPv4ProtocolNumber,
+ rxICMP: utils.RxICMPv4EchoRequest,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: header.IPv4Broadcast,
+ expectedSrc: utils.Ipv4Addr.Address,
},
{
- name: "IPv4 all-systems multicast",
- dstAddr: header.IPv4AllSystems,
+ name: "IPv4 all-systems multicast",
+ protoNum: header.IPv4ProtocolNumber,
+ rxICMP: utils.RxICMPv4EchoRequest,
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: header.IPv4AllSystems,
+ expectedSrc: utils.Ipv4Addr.Address,
},
{
- name: "IPv6 unicast",
- dstAddr: utils.Ipv6Addr.Address,
+ name: "IPv6 unicast",
+ protoNum: header.IPv6ProtocolNumber,
+ rxICMP: utils.RxICMPv6EchoRequest,
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr.Address,
+ expectedSrc: utils.Ipv6Addr.Address,
},
{
- name: "IPv6 all-nodes multicast",
- dstAddr: header.IPv6AllNodesMulticastAddress,
+ name: "IPv6 all-nodes multicast",
+ protoNum: header.IPv6ProtocolNumber,
+ rxICMP: utils.RxICMPv6EchoRequest,
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: header.IPv6AllNodesMulticastAddress,
+ expectedSrc: utils.Ipv6Addr.Address,
},
}
@@ -157,44 +136,29 @@ func TestPingMulticastBroadcast(t *testing.T) {
},
})
- var rxICMP func(*channel.Endpoint, tcpip.Address)
- var expectedSrc tcpip.Address
- var expectedDst tcpip.Address
- var protoNum tcpip.NetworkProtocolNumber
- switch l := len(test.dstAddr); l {
- case header.IPv4AddressSize:
- rxICMP = rxIPv4ICMP
- expectedSrc = utils.Ipv4Addr.Address
- expectedDst = utils.RemoteIPv4Addr
- protoNum = header.IPv4ProtocolNumber
- case header.IPv6AddressSize:
- rxICMP = rxIPv6ICMP
- expectedSrc = utils.Ipv6Addr.Address
- expectedDst = utils.RemoteIPv6Addr
- protoNum = header.IPv6ProtocolNumber
- default:
- t.Fatalf("got unexpected address length = %d bytes", l)
- }
-
- rxICMP(e, test.dstAddr)
+ test.rxICMP(e, test.srcAddr, test.dstAddr)
pkt, ok := e.Read()
if !ok {
t.Fatal("expected ICMP response")
}
- if pkt.Route.LocalAddress != expectedSrc {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc)
+ if pkt.Route.LocalAddress != test.expectedSrc {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.expectedSrc)
}
- if pkt.Route.RemoteAddress != expectedDst {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst)
+ // The destination of the response packet should be the source of the
+ // original packet.
+ if pkt.Route.RemoteAddress != test.srcAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.srcAddr)
}
- src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
- if src != expectedSrc {
- t.Errorf("got pkt source = %s, want = %s", src, expectedSrc)
+ src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+ if src != test.expectedSrc {
+ t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc)
}
- if dst != expectedDst {
- t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst)
+ // The destination of the response packet should be the source of the
+ // original packet.
+ if dst != test.srcAddr {
+ t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr)
}
})
}
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index 4455f6dd7..568a982bb 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -16,6 +16,7 @@ package route_test
import (
"bytes"
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -161,78 +162,79 @@ func TestLocalPing(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
- HandleLocal: true,
- })
- e := test.linkEndpoint()
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
+ for _, dropExternalLoopback := range []bool{true, false} {
+ t.Run(fmt.Sprintf("DropExternalLoopback=%t", dropExternalLoopback), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ DropExternalLoopbackTraffic: dropExternalLoopback,
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ DropExternalLoopbackTraffic: dropExternalLoopback,
+ }),
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ HandleLocal: true,
+ })
+ e := test.linkEndpoint()
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
- if len(test.localAddr) != 0 {
- if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
- }
- }
+ if len(test.localAddr) != 0 {
+ if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
+ }
+ }
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
- }
- defer ep.Close()
-
- connAddr := tcpip.FullAddress{Addr: test.localAddr}
- {
- err := ep.Connect(connAddr)
- if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" {
- t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff)
- }
- }
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
+ }
+ defer ep.Close()
- if test.expectedConnectErr != nil {
- return
- }
+ connAddr := tcpip.FullAddress{Addr: test.localAddr}
+ if err := ep.Connect(connAddr); err != test.expectedConnectErr {
+ t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
+ }
- payload := test.icmpBuf(t)
- var r bytes.Reader
- r.Reset(payload)
- var wOpts tcpip.WriteOptions
- if n, err := ep.Write(&r, wOpts); err != nil {
- t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
- } else if n != int64(len(payload)) {
- t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload))
- }
+ if test.expectedConnectErr != nil {
+ return
+ }
- // Wait for the endpoint to become readable.
- <-ch
+ var r bytes.Reader
+ payload := test.icmpBuf(t)
+ r.Reset(payload)
+ var wOpts tcpip.WriteOptions
+ if n, err := ep.Write(&r, wOpts); err != nil {
+ t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
+ } else if n != int64(len(payload)) {
+ t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload))
+ }
- var buf bytes.Buffer
- opts := tcpip.ReadOptions{NeedRemoteAddr: true}
- res, err := ep.Read(&buf, opts)
- if err != nil {
- t.Fatalf("ep.Read(_, %#v): %s", opts, err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- RemoteAddr: tcpip.FullAddress{Addr: test.localAddr},
- }, res, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
+ // Wait for the endpoint to become readable.
+ <-ch
- test.checkLinkEndpoint(t, e)
+ var w bytes.Buffer
+ rr, err := ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if err != nil {
+ t.Fatalf("ep.Read(...): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if rr.RemoteAddr.Addr != test.localAddr {
+ t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr)
+ }
+
+ test.checkLinkEndpoint(t, e)
+ })
+ }
})
}
}
diff --git a/pkg/tcpip/tests/utils/BUILD b/pkg/tcpip/tests/utils/BUILD
index 433004148..a9699a367 100644
--- a/pkg/tcpip/tests/utils/BUILD
+++ b/pkg/tcpip/tests/utils/BUILD
@@ -8,12 +8,15 @@ go_library(
visibility = ["//pkg/tcpip/tests:__subpackages__"],
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/ethernet",
"//pkg/tcpip/link/nested",
"//pkg/tcpip/link/pipe",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
],
)
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index f414a2234..d1c9f3a94 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -20,13 +20,16 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
"gvisor.dev/gvisor/pkg/tcpip/link/nested"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
)
// Common NIC IDs used by tests.
@@ -45,6 +48,10 @@ const (
LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
)
+const (
+ ttl = 255
+)
+
// Common IP addresses used by tests.
var (
Ipv4Addr = tcpip.AddressWithPrefix{
@@ -312,3 +319,56 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
},
})
}
+
+// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
+// the provided endpoint.
+func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(header.ICMPv4UnusedCode)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ttl,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+}
+
+// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
+// the provided endpoint.
+func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(header.ICMPv6UnusedCode)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: pkt,
+ Src: src,
+ Dst: dst,
+ }))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: ttl,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+}