summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/arp/BUILD2
-rw-r--r--pkg/tcpip/network/arp/stats_test.go2
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler.go3
-rw-r--r--pkg/tcpip/network/internal/ip/BUILD1
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection.go2
-rw-r--r--pkg/tcpip/network/internal/ip/errors.go85
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol.go1
-rw-r--r--pkg/tcpip/network/internal/ip/stats.go115
-rw-r--r--pkg/tcpip/network/internal/testutil/BUILD5
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil.go68
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil_unsafe.go26
-rw-r--r--pkg/tcpip/network/ip_test.go8
-rw-r--r--pkg/tcpip/network/ipv4/BUILD2
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go59
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go211
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go403
-rw-r--r--pkg/tcpip/network/ipv4/stats_test.go2
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go159
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go5
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go363
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go448
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go130
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go57
-rw-r--r--pkg/tcpip/network/ipv6/stats.go4
25 files changed, 1562 insertions, 600 deletions
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index 6905b9ccb..a72eb1aad 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -47,7 +47,7 @@ go_test(
library = ":arp",
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
],
)
diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go
index e867b3c3f..0df39ae81 100644
--- a/pkg/tcpip/network/arp/stats_test.go
+++ b/pkg/tcpip/network/arp/stats_test.go
@@ -19,8 +19,8 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
var _ stack.NetworkInterface = (*testInterface)(nil)
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go
index 90075a70c..56b76a284 100644
--- a/pkg/tcpip/network/internal/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go
@@ -167,8 +167,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s
resPkt := r.holes[0].pkt
for i := 1; i < len(r.holes); i++ {
- fragData := r.holes[i].pkt.Data()
- resPkt.Data().ReadFromData(fragData, fragData.Size())
+ stack.MergeFragment(resPkt, r.holes[i].pkt)
}
return resPkt, r.proto, true, memConsumed, nil
}
diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD
index d21b4c7ef..fd944ce99 100644
--- a/pkg/tcpip/network/internal/ip/BUILD
+++ b/pkg/tcpip/network/internal/ip/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "ip",
srcs = [
"duplicate_address_detection.go",
+ "errors.go",
"generic_multicast_protocol.go",
"stats.go",
],
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
index eed49f5d2..5123b7d6a 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
@@ -83,6 +83,8 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts
panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize))
}
+ configs.Validate()
+
*d = DAD{
opts: opts,
configs: configs,
diff --git a/pkg/tcpip/network/internal/ip/errors.go b/pkg/tcpip/network/internal/ip/errors.go
new file mode 100644
index 000000000..94f1cd1cb
--- /dev/null
+++ b/pkg/tcpip/network/internal/ip/errors.go
@@ -0,0 +1,85 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// ForwardingError represents an error that occured while trying to forward
+// a packet.
+type ForwardingError interface {
+ isForwardingError()
+ fmt.Stringer
+}
+
+// ErrTTLExceeded indicates that the received packet's TTL has been exceeded.
+type ErrTTLExceeded struct{}
+
+func (*ErrTTLExceeded) isForwardingError() {}
+
+func (*ErrTTLExceeded) String() string { return "ttl exceeded" }
+
+// ErrParameterProblem indicates the received packet had a problem with an IP
+// parameter.
+type ErrParameterProblem struct{}
+
+func (*ErrParameterProblem) isForwardingError() {}
+
+func (*ErrParameterProblem) String() string { return "parameter problem" }
+
+// ErrLinkLocalSourceAddress indicates the received packet had a link-local
+// source address.
+type ErrLinkLocalSourceAddress struct{}
+
+func (*ErrLinkLocalSourceAddress) isForwardingError() {}
+
+func (*ErrLinkLocalSourceAddress) String() string { return "link local destination address" }
+
+// ErrLinkLocalDestinationAddress indicates the received packet had a link-local
+// destination address.
+type ErrLinkLocalDestinationAddress struct{}
+
+func (*ErrLinkLocalDestinationAddress) isForwardingError() {}
+
+func (*ErrLinkLocalDestinationAddress) String() string { return "link local destination address" }
+
+// ErrNoRoute indicates that a route for the received packet couldn't be found.
+type ErrNoRoute struct{}
+
+func (*ErrNoRoute) isForwardingError() {}
+
+func (*ErrNoRoute) String() string { return "no route" }
+
+// ErrMessageTooLong indicates the packet was too big for the outgoing MTU.
+//
+// +stateify savable
+type ErrMessageTooLong struct{}
+
+func (*ErrMessageTooLong) isForwardingError() {}
+
+func (*ErrMessageTooLong) String() string { return "message too long" }
+
+// ErrOther indicates the packet coould not be forwarded for a reason
+// captured by the contained error.
+type ErrOther struct {
+ Err tcpip.Error
+}
+
+func (*ErrOther) isForwardingError() {}
+
+func (e *ErrOther) String() string { return fmt.Sprintf("other tcpip error: %s", e.Err) }
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
index ac35d81e7..d22974b12 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ip holds IPv4/IPv6 common utilities.
package ip
import (
diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
index d06b26309..0c2b62127 100644
--- a/pkg/tcpip/network/internal/ip/stats.go
+++ b/pkg/tcpip/network/internal/ip/stats.go
@@ -16,80 +16,145 @@ package ip
import "gvisor.dev/gvisor/pkg/tcpip"
+// LINT.IfChange(MultiCounterIPForwardingStats)
+
+// MultiCounterIPForwardingStats holds IP forwarding statistics. Each counter
+// may have several versions.
+type MultiCounterIPForwardingStats struct {
+ // Unrouteable is the number of IP packets received which were dropped
+ // because the netstack could not construct a route to their
+ // destination.
+ Unrouteable tcpip.MultiCounterStat
+
+ // ExhaustedTTL is the number of IP packets received which were dropped
+ // because their TTL was exhausted.
+ ExhaustedTTL tcpip.MultiCounterStat
+
+ // LinkLocalSource is the number of IP packets which were dropped
+ // because they contained a link-local source address.
+ LinkLocalSource tcpip.MultiCounterStat
+
+ // LinkLocalDestination is the number of IP packets which were dropped
+ // because they contained a link-local destination address.
+ LinkLocalDestination tcpip.MultiCounterStat
+
+ // PacketTooBig is the number of IP packets which were dropped because they
+ // were too big for the outgoing MTU.
+ PacketTooBig tcpip.MultiCounterStat
+
+ // ExtensionHeaderProblem is the number of IP packets which were dropped
+ // because of a problem encountered when processing an IPv6 extension
+ // header.
+ ExtensionHeaderProblem tcpip.MultiCounterStat
+
+ // Errors is the number of IP packets received which could not be
+ // successfully forwarded.
+ Errors tcpip.MultiCounterStat
+}
+
+// Init sets internal counters to track a and b counters.
+func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) {
+ m.Unrouteable.Init(a.Unrouteable, b.Unrouteable)
+ m.Errors.Init(a.Errors, b.Errors)
+ m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource)
+ m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination)
+ m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem)
+ m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig)
+ m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL)
+}
+
+// LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats)
+
// LINT.IfChange(MultiCounterIPStats)
// MultiCounterIPStats holds IP statistics, each counter may have several
// versions.
type MultiCounterIPStats struct {
- // PacketsReceived is the number of IP packets received from the link layer.
+ // PacketsReceived is the number of IP packets received from the link
+ // layer.
PacketsReceived tcpip.MultiCounterStat
- // DisabledPacketsReceived is the number of IP packets received from the link
- // layer when the IP layer is disabled.
+ // ValidPacketsReceived is the number of valid IP packets that reached the IP
+ // layer.
+ ValidPacketsReceived tcpip.MultiCounterStat
+
+ // DisabledPacketsReceived is the number of IP packets received from
+ // the link layer when the IP layer is disabled.
DisabledPacketsReceived tcpip.MultiCounterStat
- // InvalidDestinationAddressesReceived is the number of IP packets received
- // with an unknown or invalid destination address.
+ // InvalidDestinationAddressesReceived is the number of IP packets
+ // received with an unknown or invalid destination address.
InvalidDestinationAddressesReceived tcpip.MultiCounterStat
- // InvalidSourceAddressesReceived is the number of IP packets received with a
- // source address that should never have been received on the wire.
+ // InvalidSourceAddressesReceived is the number of IP packets received
+ // with a source address that should never have been received on the
+ // wire.
InvalidSourceAddressesReceived tcpip.MultiCounterStat
- // PacketsDelivered is the number of incoming IP packets that are successfully
+ // PacketsDelivered is the number of incoming IP packets successfully
// delivered to the transport layer.
PacketsDelivered tcpip.MultiCounterStat
// PacketsSent is the number of IP packets sent via WritePacket.
PacketsSent tcpip.MultiCounterStat
- // OutgoingPacketErrors is the number of IP packets which failed to write to a
- // link-layer endpoint.
+ // OutgoingPacketErrors is the number of IP packets which failed to
+ // write to a link-layer endpoint.
OutgoingPacketErrors tcpip.MultiCounterStat
- // MalformedPacketsReceived is the number of IP Packets that were dropped due
- // to the IP packet header failing validation checks.
+ // MalformedPacketsReceived is the number of IP Packets that were
+ // dropped due to the IP packet header failing validation checks.
MalformedPacketsReceived tcpip.MultiCounterStat
- // MalformedFragmentsReceived is the number of IP Fragments that were dropped
- // due to the fragment failing validation checks.
+ // MalformedFragmentsReceived is the number of IP Fragments that were
+ // dropped due to the fragment failing validation checks.
MalformedFragmentsReceived tcpip.MultiCounterStat
// IPTablesPreroutingDropped is the number of IP packets dropped in the
// Prerouting chain.
IPTablesPreroutingDropped tcpip.MultiCounterStat
- // IPTablesInputDropped is the number of IP packets dropped in the Input
- // chain.
+ // IPTablesInputDropped is the number of IP packets dropped in the
+ // Input chain.
IPTablesInputDropped tcpip.MultiCounterStat
- // IPTablesOutputDropped is the number of IP packets dropped in the Output
- // chain.
+ // IPTablesForwardDropped is the number of IP packets dropped in the
+ // Forward chain.
+ IPTablesForwardDropped tcpip.MultiCounterStat
+
+ // IPTablesOutputDropped is the number of IP packets dropped in the
+ // Output chain.
IPTablesOutputDropped tcpip.MultiCounterStat
- // IPTablesPostroutingDropped is the number of IP packets dropped in the
- // Postrouting chain.
+ // IPTablesPostroutingDropped is the number of IP packets dropped in
+ // the Postrouting chain.
IPTablesPostroutingDropped tcpip.MultiCounterStat
- // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
- // of IPStats.
+ // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option
+ // stats out of IPStats.
// OptionTimestampReceived is the number of Timestamp options seen.
OptionTimestampReceived tcpip.MultiCounterStat
- // OptionRecordRouteReceived is the number of Record Route options seen.
+ // OptionRecordRouteReceived is the number of Record Route options
+ // seen.
OptionRecordRouteReceived tcpip.MultiCounterStat
- // OptionRouterAlertReceived is the number of Router Alert options seen.
+ // OptionRouterAlertReceived is the number of Router Alert options
+ // seen.
OptionRouterAlertReceived tcpip.MultiCounterStat
// OptionUnknownReceived is the number of unknown IP options seen.
OptionUnknownReceived tcpip.MultiCounterStat
+
+ // Forwarding collects stats related to IP forwarding.
+ Forwarding MultiCounterIPForwardingStats
}
// Init sets internal counters to track a and b counters.
func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived)
+ m.ValidPacketsReceived.Init(a.ValidPacketsReceived, b.ValidPacketsReceived)
m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived)
m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived)
m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived)
@@ -100,12 +165,14 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived)
m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
+ m.IPTablesForwardDropped.Init(a.IPTablesForwardDropped, b.IPTablesForwardDropped)
m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped)
m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived)
m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived)
m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived)
m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived)
+ m.Forwarding.Init(&a.Forwarding, &b.Forwarding)
}
// LINT.ThenChange(:MultiCounterIPStats, ../../../tcpip.go:IPStats)
diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD
index 1c4f583c7..cec3e62c4 100644
--- a/pkg/tcpip/network/internal/testutil/BUILD
+++ b/pkg/tcpip/network/internal/testutil/BUILD
@@ -4,10 +4,7 @@ package(licenses = ["notice"])
go_library(
name = "testutil",
- srcs = [
- "testutil.go",
- "testutil_unsafe.go",
- ],
+ srcs = ["testutil.go"],
visibility = [
"//pkg/tcpip/network/arp:__pkg__",
"//pkg/tcpip/network/internal/fragmentation:__pkg__",
diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
index e2cf24b67..605e9ef8d 100644
--- a/pkg/tcpip/network/internal/testutil/testutil.go
+++ b/pkg/tcpip/network/internal/testutil/testutil.go
@@ -19,8 +19,6 @@ package testutil
import (
"fmt"
"math/rand"
- "reflect"
- "strings"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -129,69 +127,3 @@ func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSi
}
return pkt
}
-
-func checkFieldCounts(ref, multi reflect.Value) error {
- refTypeName := ref.Type().Name()
- multiTypeName := multi.Type().Name()
- refNumField := ref.NumField()
- multiNumField := multi.NumField()
-
- if refNumField != multiNumField {
- return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName)
- }
-
- return nil
-}
-
-func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error {
- s, ok := ref.Addr().Interface().(**tcpip.StatCounter)
- if !ok {
- return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name())
- }
-
- // The field names are expected to match (case insensitive).
- if !strings.EqualFold(refName, multiName) {
- return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName)
- }
-
- base := (*s).Value()
- m.Increment()
- if (*s).Value() != base+1 {
- return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName)
- }
-
- return nil
-}
-
-// ValidateMultiCounterStats verifies that every counter stored in multi is
-// correctly tracking its counterpart in the given counters.
-func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error {
- for _, c := range counters {
- if err := checkFieldCounts(c, multi); err != nil {
- return err
- }
- }
-
- for i := 0; i < multi.NumField(); i++ {
- multiName := multi.Type().Field(i).Name
- multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i))
-
- if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok {
- for _, c := range counters {
- if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil {
- return err
- }
- }
- } else {
- var countersNextField []reflect.Value
- for _, c := range counters {
- countersNextField = append(countersNextField, c.Field(i))
- }
- if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil {
- return err
- }
- }
- }
-
- return nil
-}
diff --git a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go b/pkg/tcpip/network/internal/testutil/testutil_unsafe.go
deleted file mode 100644
index 5ff764800..000000000
--- a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package testutil
-
-import (
- "reflect"
- "unsafe"
-)
-
-// unsafeExposeUnexportedFields takes a Value and returns a version of it in
-// which even unexported fields can be read and written.
-func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value {
- return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem()
-}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 74aad126c..bd63e0289 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -1996,8 +1996,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) {
t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
}
- if err := s.SetForwarding(test.netProto, true); err != nil {
- t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err)
+ if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err)
}
if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
@@ -2005,8 +2005,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) {
t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr)
}
- if err := s.SetForwarding(test.netProto, false); err != nil {
- t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err)
+ if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err)
}
if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 7ee0495d9..c90974693 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -62,7 +62,7 @@ go_test(
library = ":ipv4",
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
],
)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index f663fdc0b..d1a82b584 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -163,10 +163,12 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
return
}
- // Skip the ip header, then deliver the error.
- pkt.Data().TrimFront(hlen)
+ // Keep needed information before trimming header.
p := hdr.TransportProtocol()
- e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt)
+ dstAddr := hdr.DestinationAddress()
+ // Skip the ip header, then deliver the error.
+ pkt.Data().DeleteFront(hlen)
+ e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt)
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
@@ -336,14 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4DstUnreachable:
received.dstUnreachable.Increment()
- pkt.Data().TrimFront(header.ICMPv4MinimumSize)
- switch h.Code() {
+ mtu := h.MTU()
+ code := h.Code()
+ pkt.Data().DeleteFront(header.ICMPv4MinimumSize)
+ switch code {
case header.ICMPv4HostUnreachable:
e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
case header.ICMPv4PortUnreachable:
e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt)
case header.ICMPv4FragmentationNeeded:
- networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize)
+ networkMTU, err := calculateNetworkMTU(uint32(mtu), header.IPv4MinimumSize)
if err != nil {
networkMTU = 0
}
@@ -383,6 +387,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
// icmpReason is a marker interface for IPv4 specific ICMP errors.
type icmpReason interface {
isICMPReason()
+ // isForwarding indicates whether or not the error arose while attempting to
+ // forward a packet.
isForwarding() bool
}
@@ -442,6 +448,39 @@ func (r *icmpReasonParamProblem) isForwarding() bool {
return r.forwarding
}
+// icmpReasonNetworkUnreachable is an error in which the network specified in
+// the internet destination field of the datagram is unreachable.
+type icmpReasonNetworkUnreachable struct{}
+
+func (*icmpReasonNetworkUnreachable) isICMPReason() {}
+func (*icmpReasonNetworkUnreachable) isForwarding() bool {
+ // If we hit a Net Unreachable error, then we know we are operating as
+ // a router. As per RFC 792 page 5, Destination Unreachable Message,
+ //
+ // If, according to the information in the gateway's routing tables,
+ // the network specified in the internet destination field of a
+ // datagram is unreachable, e.g., the distance to the network is
+ // infinity, the gateway may send a destination unreachable message to
+ // the internet source host of the datagram.
+ return true
+}
+
+// icmpReasonFragmentationNeeded is an error where a packet requires
+// fragmentation while also having the Don't Fragment flag set, as per RFC 792
+// page 3, Destination Unreachable Message.
+type icmpReasonFragmentationNeeded struct{}
+
+func (*icmpReasonFragmentationNeeded) isICMPReason() {}
+func (*icmpReasonFragmentationNeeded) isForwarding() bool {
+ // If we hit a Don't Fragment error, then we know we are operating as a router.
+ // As per RFC 792 page 4, Destination Unreachable Message,
+ //
+ // Another case is when a datagram must be fragmented to be forwarded by a
+ // gateway yet the Don't Fragment flag is on. In this case the gateway must
+ // discard the datagram and may return a destination unreachable message.
+ return true
+}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
@@ -610,6 +649,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
counter = sent.dstUnreachable
+ case *icmpReasonNetworkUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4NetUnreachable)
+ counter = sent.dstUnreachable
+ case *icmpReasonFragmentationNeeded:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4FragmentationNeeded)
+ counter = sent.dstUnreachable
case *icmpReasonTTLExceeded:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4TTLExceeded)
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index a0bc06465..23178277a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -29,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -62,9 +63,15 @@ const (
fragmentblockSize = 8
)
+const (
+ forwardingDisabled = 0
+ forwardingEnabled = 1
+)
+
var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
+var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -81,6 +88,12 @@ type endpoint struct {
// Must be accessed using atomic operations.
enabled uint32
+ // forwarding is set to forwardingEnabled when the endpoint has forwarding
+ // enabled and forwardingDisabled when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
mu struct {
sync.RWMutex
@@ -150,14 +163,32 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
delete(p.mu.eps, nicID)
}
-// transitionForwarding transitions the endpoint's forwarding status to
-// forwarding.
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) Forwarding() bool {
+ return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
+}
+
+// setForwarding sets the forwarding status for the endpoint.
//
-// Must only be called when the forwarding status changes.
-func (e *endpoint) transitionForwarding(forwarding bool) {
+// Returns true if the forwarding status was updated.
+func (e *endpoint) setForwarding(v bool) bool {
+ forwarding := uint32(forwardingDisabled)
+ if v {
+ forwarding = forwardingEnabled
+ }
+
+ return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) SetForwarding(forwarding bool) {
e.mu.Lock()
defer e.mu.Unlock()
+ if !e.setForwarding(forwarding) {
+ return
+ }
+
if forwarding {
// There does not seem to be an RFC requirement for a node to join the all
// routers multicast address but
@@ -433,6 +464,12 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn
}
if packetMustBeFragmented(pkt, networkMTU) {
+ h := header.IPv4(pkt.NetworkHeader().View())
+ if h.Flags()&header.IPv4FlagDontFragment != 0 && pkt.NetworkPacketInfo.IsForwardedPacket {
+ // TODO(gvisor.dev/issue/5919): Handle error condition in which DontFragment
+ // is set but the packet must be fragmented for the non-forwarding case.
+ return &tcpip.ErrMessageTooLong{}
+ }
sent, remain, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
@@ -599,22 +636,25 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
}
// forwardPacket attempts to forward a packet to its final destination.
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
h := header.IPv4(pkt.NetworkHeader().View())
dstAddr := h.DestinationAddress()
- if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) {
- // As per RFC 3927 section 7,
- //
- // A router MUST NOT forward a packet with an IPv4 Link-Local source or
- // destination address, irrespective of the router's default route
- // configuration or routes obtained from dynamic routing protocols.
- //
- // A router which receives a packet with an IPv4 Link-Local source or
- // destination address MUST NOT forward the packet. This prevents
- // forwarding of packets back onto the network segment from which they
- // originated, or to any other segment.
- return nil
+ // As per RFC 3927 section 7,
+ //
+ // A router MUST NOT forward a packet with an IPv4 Link-Local source or
+ // destination address, irrespective of the router's default route
+ // configuration or routes obtained from dynamic routing protocols.
+ //
+ // A router which receives a packet with an IPv4 Link-Local source or
+ // destination address MUST NOT forward the packet. This prevents
+ // forwarding of packets back onto the network segment from which they
+ // originated, or to any other segment.
+ if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) {
+ return &ip.ErrLinkLocalSourceAddress{}
+ }
+ if header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) {
+ return &ip.ErrLinkLocalDestinationAddress{}
}
ttl := h.TTL()
@@ -624,7 +664,12 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// If the gateway processing a datagram finds the time to live field
// is zero it must discard the datagram. The gateway may also notify
// the source host via the time exceeded message.
- return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt)
+ //
+ // We return the original error rather than the result of returning
+ // the ICMP packet because the original error is more relevant to
+ // the caller.
+ _ = e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt)
+ return &ip.ErrTTLExceeded{}
}
if opts := h.Options(); len(opts) != 0 {
@@ -635,10 +680,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
pointer: optProblem.Pointer,
forwarding: true,
}, pkt)
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
- e.stats.ip.MalformedPacketsReceived.Increment()
}
- return nil // option problems are not reported locally.
+ return &ip.ErrParameterProblem{}
}
copied := copy(opts, newOpts)
if copied != len(newOpts) {
@@ -655,18 +698,44 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
}
}
+ stk := e.protocol.stack
+
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+ inNicName := stk.FindNICNameFromID(e.nic.ID())
+ outNicName := stk.FindNICNameFromID(ep.nic.ID())
+ if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesForwardDropped.Increment()
+ return nil
+ }
+
ep.handleValidatedPacket(h, pkt)
return nil
}
- r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
+ r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ switch err.(type) {
+ case nil:
+ case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable:
+ // We return the original error rather than the result of returning
+ // the ICMP packet because the original error is more relevant to
+ // the caller.
+ _ = e.protocol.returnError(&icmpReasonNetworkUnreachable{}, pkt)
+ return &ip.ErrNoRoute{}
+ default:
+ return &ip.ErrOther{Err: err}
}
defer r.Release()
+ inNicName := stk.FindNICNameFromID(e.nic.ID())
+ outNicName := stk.FindNICNameFromID(r.NICID())
+ if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesForwardDropped.Increment()
+ return nil
+ }
+
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
@@ -680,10 +749,28 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// spent, the field must be decremented by 1.
newHdr.SetTTL(ttl - 1)
- return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: buffer.View(newHdr).ToVectorisedView(),
- }))
+ IsForwardedPacket: true,
+ })); err.(type) {
+ case nil:
+ return nil
+ case *tcpip.ErrMessageTooLong:
+ // As per RFC 792, page 4, Destination Unreachable:
+ //
+ // Another case is when a datagram must be fragmented to be forwarded by a
+ // gateway yet the Don't Fragment flag is on. In this case the gateway must
+ // discard the datagram and may return a destination unreachable message.
+ //
+ // WriteHeaderIncludedPacket checks for the presence of the Don't Fragment bit
+ // while sending the packet and returns this error iff fragmentation is
+ // necessary and the bit is also set.
+ _ = e.protocol.returnError(&icmpReasonFragmentationNeeded{}, pkt)
+ return &ip.ErrMessageTooLong{}
+ default:
+ return &ip.ErrOther{Err: err}
+ }
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
@@ -764,6 +851,7 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.stats
+ stats.ip.ValidPacketsReceived.Increment()
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
@@ -794,11 +882,30 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
addressEndpoint.DecRef()
pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
} else if !e.IsInGroup(dstAddr) {
- if !e.protocol.Forwarding() {
+ if !e.Forwarding() {
stats.ip.InvalidDestinationAddressesReceived.Increment()
return
}
- _ = e.forwardPacket(pkt)
+ switch err := e.forwardPacket(pkt); err.(type) {
+ case nil:
+ return
+ case *ip.ErrLinkLocalSourceAddress:
+ stats.ip.Forwarding.LinkLocalSource.Increment()
+ case *ip.ErrLinkLocalDestinationAddress:
+ stats.ip.Forwarding.LinkLocalDestination.Increment()
+ case *ip.ErrTTLExceeded:
+ stats.ip.Forwarding.ExhaustedTTL.Increment()
+ case *ip.ErrNoRoute:
+ stats.ip.Forwarding.Unrouteable.Increment()
+ case *ip.ErrParameterProblem:
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
+ case *ip.ErrMessageTooLong:
+ stats.ip.Forwarding.PacketTooBig.Increment()
+ default:
+ panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt))
+ }
+ stats.ip.Forwarding.Errors.Increment()
return
}
@@ -955,8 +1062,8 @@ func (e *endpoint) Close() {
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
if err == nil {
@@ -967,8 +1074,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
// RemovePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
}
@@ -981,8 +1088,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
// AcquireAssignedAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
loopback := e.nic.IsLoopback()
return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool {
@@ -1067,7 +1174,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
return &e.stats.localStats
}
-var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1088,12 +1194,6 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- // forwarding is set to 1 when the protocol has forwarding enabled and 0
- // when it is disabled.
- //
- // Must be accessed using atomic operations.
- forwarding uint32
-
ids []uint32
hashIV uint32
@@ -1206,35 +1306,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) Forwarding() bool {
- return uint8(atomic.LoadUint32(&p.forwarding)) == 1
-}
-
-// setForwarding sets the forwarding status for the protocol.
-//
-// Returns true if the forwarding status was updated.
-func (p *protocol) setForwarding(v bool) bool {
- if v {
- return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
- }
- return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
-}
-
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) SetForwarding(v bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if !p.setForwarding(v) {
- return
- }
-
- for _, ep := range p.mu.eps {
- ep.transitionForwarding(v)
- }
-}
-
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 7d413c455..da9cc0ae8 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -112,67 +112,103 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
+type forwardedPacket struct {
+ fragments []fragmentInfo
+}
+
func TestForwarding(t *testing.T) {
const (
- nicID1 = 1
- nicID2 = 2
+ incomingNICID = 1
+ outgoingNICID = 2
randomSequence = 123
randomIdent = 42
randomTimeOffset = 0x10203040
)
- ipv4Addr1 := tcpip.AddressWithPrefix{
+ incomingIPv4Addr := tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
PrefixLen: 8,
}
- ipv4Addr2 := tcpip.AddressWithPrefix{
+ outgoingIPv4Addr := tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
PrefixLen: 8,
}
- remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4())
- remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4())
+ outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ remoteIPv4Addr1 := tcptestutil.MustParse4("10.0.0.2")
+ remoteIPv4Addr2 := tcptestutil.MustParse4("11.0.0.2")
+ unreachableIPv4Addr := tcptestutil.MustParse4("12.0.0.2")
+ multicastIPv4Addr := tcptestutil.MustParse4("225.0.0.0")
+ linkLocalIPv4Addr := tcptestutil.MustParse4("169.254.0.0")
tests := []struct {
- name string
- TTL uint8
- expectErrorICMP bool
- options header.IPv4Options
- forwardedOptions header.IPv4Options
- icmpType header.ICMPv4Type
- icmpCode header.ICMPv4Code
+ name string
+ TTL uint8
+ sourceAddr tcpip.Address
+ destAddr tcpip.Address
+ expectErrorICMP bool
+ ipFlags uint8
+ mtu uint32
+ payloadLength int
+ options header.IPv4Options
+ forwardedOptions header.IPv4Options
+ icmpType header.ICMPv4Type
+ icmpCode header.ICMPv4Code
+ expectPacketUnrouteableError bool
+ expectLinkLocalSourceError bool
+ expectLinkLocalDestError bool
+ expectPacketForwarded bool
+ expectedFragmentsForwarded []fragmentInfo
}{
{
name: "TTL of zero",
TTL: 0,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
expectErrorICMP: true,
icmpType: header.ICMPv4TimeExceeded,
icmpCode: header.ICMPv4TTLExceeded,
+ mtu: ipv4.MaxTotalSize,
},
{
- name: "TTL of one",
- TTL: 1,
- expectErrorICMP: false,
+ name: "TTL of one",
+ TTL: 1,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ expectPacketForwarded: true,
+ mtu: ipv4.MaxTotalSize,
},
{
- name: "TTL of two",
- TTL: 2,
- expectErrorICMP: false,
+ name: "TTL of two",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ expectPacketForwarded: true,
+ mtu: ipv4.MaxTotalSize,
},
{
- name: "Max TTL",
- TTL: math.MaxUint8,
- expectErrorICMP: false,
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ expectPacketForwarded: true,
+ mtu: ipv4.MaxTotalSize,
},
{
- name: "four EOL options",
- TTL: 2,
- expectErrorICMP: false,
- options: header.IPv4Options{0, 0, 0, 0},
- forwardedOptions: header.IPv4Options{0, 0, 0, 0},
+ name: "four EOL options",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ expectPacketForwarded: true,
+ mtu: ipv4.MaxTotalSize,
+ options: header.IPv4Options{0, 0, 0, 0},
+ forwardedOptions: header.IPv4Options{0, 0, 0, 0},
},
{
- name: "TS type 1 full",
- TTL: 2,
+ name: "TS type 1 full",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ mtu: ipv4.MaxTotalSize,
options: header.IPv4Options{
68, 12, 13, 0xF1,
192, 168, 1, 12,
@@ -183,8 +219,11 @@ func TestForwarding(t *testing.T) {
icmpCode: header.ICMPv4UnusedCode,
},
{
- name: "TS type 0",
- TTL: 2,
+ name: "TS type 0",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ mtu: ipv4.MaxTotalSize,
options: header.IPv4Options{
68, 24, 21, 0x00,
1, 2, 3, 4,
@@ -201,10 +240,14 @@ func TestForwarding(t *testing.T) {
13, 14, 15, 16,
0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
},
+ expectPacketForwarded: true,
},
{
- name: "end of options list",
- TTL: 2,
+ name: "end of options list",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ mtu: ipv4.MaxTotalSize,
options: header.IPv4Options{
68, 12, 13, 0x11,
192, 168, 1, 12,
@@ -220,11 +263,89 @@ func TestForwarding(t *testing.T) {
0, 0, 0, // 7 bytes unknown option removed.
0, 0, 0, 0,
},
+ expectPacketForwarded: true,
+ },
+ {
+ name: "Network unreachable",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: unreachableIPv4Addr,
+ expectErrorICMP: true,
+ mtu: ipv4.MaxTotalSize,
+ icmpType: header.ICMPv4DstUnreachable,
+ icmpCode: header.ICMPv4NetUnreachable,
+ expectPacketUnrouteableError: true,
+ },
+ {
+ name: "Multicast destination",
+ TTL: 2,
+ destAddr: multicastIPv4Addr,
+ expectPacketUnrouteableError: true,
+ },
+ {
+ name: "Link local destination",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: linkLocalIPv4Addr,
+ expectLinkLocalDestError: true,
+ },
+ {
+ name: "Link local source",
+ TTL: 2,
+ sourceAddr: linkLocalIPv4Addr,
+ destAddr: remoteIPv4Addr2,
+ expectLinkLocalSourceError: true,
+ },
+ {
+ name: "Fragmentation needed and DF set",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ ipFlags: header.IPv4FlagDontFragment,
+ // We've picked this MTU because it is:
+ //
+ // 1) Greater than the minimum MTU that IPv4 hosts are required to process
+ // (576 bytes). As per RFC 1812, Section 4.3.2.3:
+ //
+ // The ICMP datagram SHOULD contain as much of the original datagram as
+ // possible without the length of the ICMP datagram exceeding 576 bytes.
+ //
+ // Therefore, setting an MTU greater than 576 bytes ensures that we can fit a
+ // complete ICMP packet on the incoming endpoint (and make assertions about
+ // it).
+ //
+ // 2) Less than `ipv4.MaxTotalSize`, which lets us build an IPv4 packet whose
+ // size exceeds the MTU.
+ mtu: 1000,
+ payloadLength: 1004,
+ expectErrorICMP: true,
+ icmpType: header.ICMPv4DstUnreachable,
+ icmpCode: header.ICMPv4FragmentationNeeded,
+ },
+ {
+ name: "Fragmentation needed and DF not set",
+ TTL: 2,
+ sourceAddr: remoteIPv4Addr1,
+ destAddr: remoteIPv4Addr2,
+ mtu: 1000,
+ payloadLength: 1004,
+ expectPacketForwarded: true,
+ // Combined, these fragments have length of 1012 octets, which is equal to
+ // the length of the payload (1004 octets), plus the length of the ICMP
+ // header (8 octets).
+ expectedFragmentsForwarded: []fragmentInfo{
+ // The first fragment has a length of the greatest multiple of 8 which is
+ // less than or equal to to `mtu - header.IPv4MinimumSize`.
+ {offset: 0, payloadSize: uint16(976), more: true},
+ // The next fragment holds the rest of the packet.
+ {offset: uint16(976), payloadSize: 36, more: false},
+ },
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
clock := faketime.NewManualClock()
+
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
@@ -236,46 +357,52 @@ func TestForwarding(t *testing.T) {
clock.Advance(time.Millisecond * randomTimeOffset)
// We expect at most a single packet in response to our ICMP Echo Request.
- e1 := channel.New(1, ipv4.MaxTotalSize, "")
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ incomingEndpoint := channel.New(1, test.mtu, "")
+ if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
- ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1}
- if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err)
+ incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr}
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err)
}
- e2 := channel.New(1, ipv4.MaxTotalSize, "")
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ expectedEmittedPacketCount := 1
+ if len(test.expectedFragmentsForwarded) > expectedEmittedPacketCount {
+ expectedEmittedPacketCount = len(test.expectedFragmentsForwarded)
+ }
+ outgoingEndpoint := channel.New(expectedEmittedPacketCount, test.mtu, outgoingLinkAddr)
+ if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
- ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2}
- if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err)
+ outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr}
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{
- Destination: ipv4Addr1.Subnet(),
- NIC: nicID1,
+ Destination: incomingIPv4Addr.Subnet(),
+ NIC: incomingNICID,
},
{
- Destination: ipv4Addr2.Subnet(),
- NIC: nicID2,
+ Destination: outgoingIPv4Addr.Subnet(),
+ NIC: outgoingNICID,
},
})
- if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err)
+ if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err)
}
ipHeaderLength := header.IPv4MinimumSize + len(test.options)
if ipHeaderLength > header.IPv4MaximumHeaderSize {
t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
}
- totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
- hdr := buffer.NewPrependable(int(totalLen))
- icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpHeaderLength := header.ICMPv4MinimumSize
+ totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength
+ hdr := buffer.NewPrependable(totalLength)
+ hdr.Prepend(test.payloadLength)
+ icmp := header.ICMPv4(hdr.Prepend(icmpHeaderLength))
icmp.SetIdent(randomIdent)
icmp.SetSequence(randomSequence)
icmp.SetType(header.ICMPv4Echo)
@@ -284,11 +411,12 @@ func TestForwarding(t *testing.T) {
icmp.SetChecksum(^header.Checksum(icmp, 0))
ip := header.IPv4(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv4Fields{
- TotalLength: totalLen,
+ TotalLength: uint16(totalLength),
Protocol: uint8(header.ICMPv4ProtocolNumber),
TTL: test.TTL,
- SrcAddr: remoteIPv4Addr1,
- DstAddr: remoteIPv4Addr2,
+ SrcAddr: test.sourceAddr,
+ DstAddr: test.destAddr,
+ Flags: test.ipFlags,
})
if len(test.options) != 0 {
ip.SetHeaderLength(uint8(ipHeaderLength))
@@ -303,51 +431,122 @@ func TestForwarding(t *testing.T) {
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
})
- e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+ requestPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ incomingEndpoint.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+
+ reply, ok := incomingEndpoint.Read()
if test.expectErrorICMP {
- reply, ok := e1.Read()
if !ok {
t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
}
+ // We expect the ICMP packet to contain as much of the original packet as
+ // possible up to a limit of 576 bytes, split between payload, IP header,
+ // and ICMP header.
+ expectedICMPPayloadLength := func() int {
+ maxICMPPacketLength := header.IPv4MinimumProcessableDatagramSize
+ maxICMPPayloadLength := maxICMPPacketLength - icmpHeaderLength - ipHeaderLength
+ if len(hdr.View()) > maxICMPPayloadLength {
+ return maxICMPPayloadLength
+ }
+ return len(hdr.View())
+ }
+
checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(ipv4Addr1.Address),
- checker.DstAddr(remoteIPv4Addr1),
+ checker.SrcAddr(incomingIPv4Addr.Address),
+ checker.DstAddr(test.sourceAddr),
checker.TTL(ipv4.DefaultTTL),
checker.ICMPv4(
checker.ICMPv4Checksum(),
checker.ICMPv4Type(test.icmpType),
checker.ICMPv4Code(test.icmpCode),
- checker.ICMPv4Payload([]byte(hdr.View())),
+ checker.ICMPv4Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])),
),
)
+ } else if ok {
+ t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply)
+ }
- if n := e2.Drain(); n != 0 {
- t.Fatalf("got e2.Drain() = %d, want = 0", n)
+ if test.expectPacketForwarded {
+ if len(test.expectedFragmentsForwarded) != 0 {
+ fragmentedPackets := []*stack.PacketBuffer{}
+ for i := 0; i < len(test.expectedFragmentsForwarded); i++ {
+ reply, ok = outgoingEndpoint.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo fragment through outgoing NIC")
+ }
+ fragmentedPackets = append(fragmentedPackets, reply.Pkt)
+ }
+
+ // The forwarded packet's TTL will have been decremented.
+ ipHeader := header.IPv4(requestPkt.NetworkHeader().View())
+ ipHeader.SetTTL(ipHeader.TTL() - 1)
+
+ // Forwarded packets have available header bytes equalling the sum of the
+ // maximum IP header size and the maximum size allocated for link layer
+ // headers. In this case, no size is allocated for link layer headers.
+ expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize
+ if err := compareFragments(fragmentedPackets, requestPkt, uint32(test.mtu), test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil {
+ t.Error(err)
+ }
+ } else {
+ reply, ok = outgoingEndpoint.Read()
+ if !ok {
+ t.Fatal("expected ICMP Echo packet through outgoing NIC")
+ }
+
+ checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(test.sourceAddr),
+ checker.DstAddr(test.destAddr),
+ checker.TTL(test.TTL-1),
+ checker.IPv4Options(test.forwardedOptions),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Payload(nil),
+ ),
+ )
}
} else {
- reply, ok := e2.Read()
- if !ok {
- t.Fatal("expected ICMP Echo packet through outgoing NIC")
+ if reply, ok = outgoingEndpoint.Read(); ok {
+ t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
+ }
+ }
+ boolToInt := func(val bool) uint64 {
+ if val {
+ return 1
}
+ return 0
+ }
- checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(remoteIPv4Addr1),
- checker.DstAddr(remoteIPv4Addr2),
- checker.TTL(test.TTL-1),
- checker.IPv4Options(test.forwardedOptions),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Type(header.ICMPv4Echo),
- checker.ICMPv4Code(header.ICMPv4UnusedCode),
- checker.ICMPv4Payload(nil),
- ),
- )
+ if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want)
+ }
- if n := e1.Drain(); n != 0 {
- t.Fatalf("got e1.Drain() = %d, want = 0", n)
- }
+ if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), boolToInt(test.icmpType == header.ICMPv4ParamProblem); got != want {
+ t.Errorf("got s.Stats().IP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 0); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpCode == header.ICMPv4FragmentationNeeded); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want)
}
})
}
@@ -1170,13 +1369,25 @@ func TestIPv4Sanity(t *testing.T) {
}
}
-// comparePayloads compared the contents of all the packets against the contents
-// of the source packet.
-func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
+// compareFragments compares the contents of a set of fragmented packets against
+// the contents of a source packet.
+//
+// If withIPHeader is set to true, we will validate the fragmented packets' IP
+// headers against the source packet's IP header. If set to false, we validate
+// the fragmented packets' IP headers against each other.
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber, withIPHeader bool, expectedAvailableHeaderBytes int) error {
// Make a complete array of the sourcePacket packet.
- source := header.IPv4(packets[0].NetworkHeader().View())
+ var source header.IPv4
vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
- source = append(source, vv.ToView()...)
+
+ // If the packet to be fragmented contains an IPv4 header, use that header for
+ // validating fragment headers. Else, use the header of the first fragment.
+ if withIPHeader {
+ source = header.IPv4(vv.ToView())
+ } else {
+ source = header.IPv4(packets[0].NetworkHeader().View())
+ source = append(source, vv.ToView()...)
+ }
// Make a copy of the IP header, which will be modified in some fields to make
// an expected header.
@@ -1199,12 +1410,12 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
if got := fragmentIPHeader.TransportProtocol(); got != proto {
return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto))
}
- if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve {
- return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve)
- }
if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want {
return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want)
}
+ if got := packet.AvailableHeaderBytes(); got != expectedAvailableHeaderBytes {
+ return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, expectedAvailableHeaderBytes)
+ }
if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want {
return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want)
}
@@ -1220,6 +1431,14 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize)
sourceCopy.SetChecksum(0)
sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
+
+ // If we are validating against the original IP header, we should exclude the
+ // ID field, which will only be set fo fragmented packets.
+ if withIPHeader {
+ fragmentIPHeader.SetID(0)
+ fragmentIPHeader.SetChecksum(0)
+ fragmentIPHeader.SetChecksum(^fragmentIPHeader.CalculateChecksum())
+ }
if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" {
return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
}
@@ -1348,7 +1567,7 @@ func TestFragmentationWritePacket(t *testing.T) {
if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
}
- if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil {
t.Error(err)
}
})
@@ -1429,7 +1648,7 @@ func TestFragmentationWritePackets(t *testing.T) {
}
fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
- if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil {
t.Error(err)
}
})
diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go
index a637f9d50..d1f9e3cf5 100644
--- a/pkg/tcpip/network/ipv4/stats_test.go
+++ b/pkg/tcpip/network/ipv4/stats_test.go
@@ -19,8 +19,8 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
var _ stack.NetworkInterface = (*testInterface)(nil)
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index db998e83e..f99cbf8f3 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -45,6 +45,7 @@ go_test(
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 1319db32b..307e1972d 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -181,10 +181,13 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
return
}
+ // Keep needed information before trimming header.
+ p := hdr.TransportProtocol()
+ dstAddr := hdr.DestinationAddress()
+
// Skip the IP header, then handle the fragmentation header if there
// is one.
- pkt.Data().TrimFront(header.IPv6MinimumSize)
- p := hdr.TransportProtocol()
+ pkt.Data().DeleteFront(header.IPv6MinimumSize)
if p == header.IPv6FragmentHeader {
f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize)
if !ok {
@@ -196,14 +199,14 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// because they don't have the transport headers.
return
}
+ p = fragHdr.TransportProtocol()
// Skip fragmentation header and find out the actual protocol
// number.
- pkt.Data().TrimFront(header.IPv6FragmentHeaderSize)
- p = fragHdr.TransportProtocol()
+ pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize)
}
- e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt)
+ e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt)
}
// getLinkAddrOption searches NDP options for a given link address option using
@@ -327,11 +330,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
received.invalid.Increment()
return
}
- pkt.Data().TrimFront(header.ICMPv6PacketTooBigMinimumSize)
networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize)
if err != nil {
networkMTU = 0
}
+ pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize)
e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt)
case header.ICMPv6DstUnreachable:
@@ -341,8 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
received.invalid.Increment()
return
}
- pkt.Data().TrimFront(header.ICMPv6DstUnreachableMinimumSize)
- switch header.ICMPv6(hdr).Code() {
+ code := header.ICMPv6(hdr).Code()
+ pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize)
+ switch code {
case header.ICMPv6NetworkUnreachable:
e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt)
case header.ICMPv6PortUnreachable:
@@ -741,11 +745,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- stack := e.protocol.stack
-
- // Is the networking stack operating as a router?
- if !stack.Forwarding(ProtocolNumber) {
- // ... No, silently drop the packet.
+ if !e.Forwarding() {
received.routerOnlyPacketsDroppedByHost.Increment()
return
}
@@ -951,6 +951,19 @@ func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
// icmpReason is a marker interface for IPv6 specific ICMP errors.
type icmpReason interface {
isICMPReason()
+ // isForwarding indicates whether or not the error arose while attempting to
+ // forward a packet.
+ isForwarding() bool
+ // respondToMulticast indicates whether this error falls under the exception
+ // outlined by RFC 4443 section 2.4 point e.3 exception 2:
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are two
+ // exceptions to this rule: (1) the Packet Too Big Message (Section 3.2) to
+ // allow Path MTU discovery to work for IPv6 multicast, and (2) the Parameter
+ // Problem Message, Code 2 (Section 3.4) reporting an unrecognized IPv6
+ // option (see Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ respondsToMulticast() bool
}
// icmpReasonParameterProblem is an error during processing of extension headers
@@ -958,18 +971,6 @@ type icmpReason interface {
type icmpReasonParameterProblem struct {
code header.ICMPv6Code
- // respondToMulticast indicates that we are sending a packet that falls under
- // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2:
- //
- // (e.3) A packet destined to an IPv6 multicast address. (There are
- // two exceptions to this rule: (1) the Packet Too Big Message
- // (Section 3.2) to allow Path MTU discovery to work for IPv6
- // multicast, and (2) the Parameter Problem Message, Code 2
- // (Section 3.4) reporting an unrecognized IPv6 option (see
- // Section 4.2 of [IPv6]) that has the Option Type highest-
- // order two bits set to 10).
- respondToMulticast bool
-
// pointer is defined in the RFC 4443 setion 3.4 which reads:
//
// Pointer Identifies the octet offset within the invoking packet
@@ -979,9 +980,20 @@ type icmpReasonParameterProblem struct {
// packet if the field in error is beyond what can fit
// in the maximum size of an ICMPv6 error message.
pointer uint32
+
+ forwarding bool
+
+ respondToMulticast bool
}
func (*icmpReasonParameterProblem) isICMPReason() {}
+func (p *icmpReasonParameterProblem) isForwarding() bool {
+ return p.forwarding
+}
+
+func (p *icmpReasonParameterProblem) respondsToMulticast() bool {
+ return p.respondToMulticast
+}
// icmpReasonPortUnreachable is an error where the transport protocol has no
// listener and no alternative means to inform the sender.
@@ -989,12 +1001,76 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+func (*icmpReasonPortUnreachable) isForwarding() bool {
+ return false
+}
+
+func (*icmpReasonPortUnreachable) respondsToMulticast() bool {
+ return false
+}
+
+// icmpReasonNetUnreachable is an error where no route can be found to the
+// network of the final destination.
+type icmpReasonNetUnreachable struct{}
+
+func (*icmpReasonNetUnreachable) isICMPReason() {}
+
+func (*icmpReasonNetUnreachable) isForwarding() bool {
+ // If we hit a Network Unreachable error, then we also know we are
+ // operating as a router. As per RFC 4443 section 3.1:
+ //
+ // If the reason for the failure to deliver is lack of a matching
+ // entry in the forwarding node's routing table, the Code field is
+ // set to 0 (Network Unreachable).
+ return true
+}
+
+func (*icmpReasonNetUnreachable) respondsToMulticast() bool {
+ return false
+}
+
+// icmpReasonFragmentationNeeded is an error where a packet is to big to be sent
+// out through the outgoing MTU, as per RFC 4443 page 9, Packet Too Big Message.
+type icmpReasonPacketTooBig struct{}
+
+func (*icmpReasonPacketTooBig) isICMPReason() {}
+
+func (*icmpReasonPacketTooBig) isForwarding() bool {
+ // If we hit a Packet Too Big error, then we know we are operating as a router.
+ // As per RFC 4443 section 3.2:
+ //
+ // A Packet Too Big MUST be sent by a router in response to a packet that it
+ // cannot forward because the packet is larger than the MTU of the outgoing
+ // link.
+ return true
+}
+
+func (*icmpReasonPacketTooBig) respondsToMulticast() bool {
+ return true
+}
+
// icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in
// transit to its final destination, as per RFC 4443 section 3.3.
type icmpReasonHopLimitExceeded struct{}
func (*icmpReasonHopLimitExceeded) isICMPReason() {}
+func (*icmpReasonHopLimitExceeded) isForwarding() bool {
+ // If we hit a Hop Limit Exceeded error, then we know we are operating
+ // as a router. As per RFC 4443 section 3.3:
+ //
+ // If a router receives a packet with a Hop Limit of zero, or if a
+ // router decrements a packet's Hop Limit to zero, it MUST discard
+ // the packet and originate an ICMPv6 Time Exceeded message with Code
+ // 0 to the source of the packet. This indicates either a routing
+ // loop or too small an initial Hop Limit value.
+ return true
+}
+
+func (*icmpReasonHopLimitExceeded) respondsToMulticast() bool {
+ return false
+}
+
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
// the reception of the first-arriving fragment of that packet.
@@ -1002,6 +1078,14 @@ type icmpReasonReassemblyTimeout struct{}
func (*icmpReasonReassemblyTimeout) isICMPReason() {}
+func (*icmpReasonReassemblyTimeout) isForwarding() bool {
+ return false
+}
+
+func (*icmpReasonReassemblyTimeout) respondsToMulticast() bool {
+ return false
+}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error {
@@ -1030,25 +1114,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
// Section 4.2 of [IPv6]) that has the Option Type highest-
// order two bits set to 10).
//
- var allowResponseToMulticast bool
- if reason, ok := reason.(*icmpReasonParameterProblem); ok {
- allowResponseToMulticast = reason.respondToMulticast
- }
-
+ allowResponseToMulticast := reason.respondsToMulticast()
isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst)
if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any {
return nil
}
- // If we hit a Hop Limit Exceeded error, then we know we are operating as a
- // router. As per RFC 4443 section 3.3:
- //
- // If a router receives a packet with a Hop Limit of zero, or if a
- // router decrements a packet's Hop Limit to zero, it MUST discard the
- // packet and originate an ICMPv6 Time Exceeded message with Code 0 to
- // the source of the packet. This indicates either a routing loop or
- // too small an initial Hop Limit value.
- //
// If we are operating as a router, do not use the packet's destination
// address as the response's source address as we should not own the
// destination address of a packet we are forwarding.
@@ -1058,7 +1129,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
// packet as "multicast addresses must not be used as source addresses in IPv6
// packets", as per RFC 4291 section 2.7.
localAddr := origIPHdrDst
- if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast {
+ if reason.isForwarding() || isOrigDstMulticast {
localAddr = ""
}
// Even if we were able to receive a packet from some remote, we may not have
@@ -1147,6 +1218,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6PortUnreachable)
counter = sent.dstUnreachable
+ case *icmpReasonNetUnreachable:
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6NetworkUnreachable)
+ counter = sent.dstUnreachable
+ case *icmpReasonPacketTooBig:
+ icmpHdr.SetType(header.ICMPv6PacketTooBig)
+ icmpHdr.SetCode(header.ICMPv6UnusedCode)
+ counter = sent.packetTooBig
case *icmpReasonHopLimitExceeded:
icmpHdr.SetType(header.ICMPv6TimeExceeded)
icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index e457be3cf..040cd4bc8 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -673,8 +673,9 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
if isRouter {
- // Enabling forwarding makes the stack act as a router.
- s.SetForwarding(ProtocolNumber, true)
+ if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err)
+ }
}
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(_, _) = %s", err)
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index f7510c243..95e11ac51 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -63,6 +63,11 @@ const (
buckets = 2048
)
+const (
+ forwardingDisabled = 0
+ forwardingEnabled = 1
+)
+
// policyTable is the default policy table defined in RFC 6724 section 2.1.
//
// A more human-readable version:
@@ -168,6 +173,7 @@ func getLabel(addr tcpip.Address) uint8 {
var _ stack.DuplicateAddressDetector = (*endpoint)(nil)
var _ stack.LinkAddressResolver = (*endpoint)(nil)
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
+var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
var _ stack.AddressableEndpoint = (*endpoint)(nil)
var _ stack.NetworkEndpoint = (*endpoint)(nil)
@@ -187,6 +193,12 @@ type endpoint struct {
// Must be accessed using atomic operations.
enabled uint32
+ // forwarding is set to forwardingEnabled when the endpoint has forwarding
+ // enabled and forwardingDisabled when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
mu struct {
sync.RWMutex
@@ -405,27 +417,39 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t
}
}
-// transitionForwarding transitions the endpoint's forwarding status to
-// forwarding.
+// Forwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) Forwarding() bool {
+ return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
+}
+
+// setForwarding sets the forwarding status for the endpoint.
//
-// Must only be called when the forwarding status changes.
-func (e *endpoint) transitionForwarding(forwarding bool) {
+// Returns true if the forwarding status was updated.
+func (e *endpoint) setForwarding(v bool) bool {
+ forwarding := uint32(forwardingDisabled)
+ if v {
+ forwarding = forwardingEnabled
+ }
+
+ return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
+}
+
+// SetForwarding implements stack.ForwardingNetworkEndpoint.
+func (e *endpoint) SetForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if !e.setForwarding(forwarding) {
+ return
+ }
+
allRoutersGroups := [...]tcpip.Address{
header.IPv6AllRoutersInterfaceLocalMulticastAddress,
header.IPv6AllRoutersLinkLocalMulticastAddress,
header.IPv6AllRoutersSiteLocalMulticastAddress,
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
if forwarding {
- // When transitioning into an IPv6 router, host-only state (NDP discovered
- // routers, discovered on-link prefixes, and auto-generated addresses) is
- // cleaned up/invalidated and NDP router solicitations are stopped.
- e.mu.ndp.stopSolicitingRouters()
- e.mu.ndp.cleanupState(true /* hostOnly */)
-
// As per RFC 4291 section 2.8:
//
// A router is required to recognize all addresses that a host is
@@ -449,28 +473,19 @@ func (e *endpoint) transitionForwarding(forwarding bool) {
panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err))
}
}
-
- return
- }
-
- for _, g := range allRoutersGroups {
- switch err := e.leaveGroupLocked(g).(type) {
- case nil:
- case *tcpip.ErrBadLocalAddress:
- // The endpoint may have already left the multicast group.
- default:
- panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err))
+ } else {
+ for _, g := range allRoutersGroups {
+ switch err := e.leaveGroupLocked(g).(type) {
+ case nil:
+ case *tcpip.ErrBadLocalAddress:
+ // The endpoint may have already left the multicast group.
+ default:
+ panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err))
+ }
}
}
- // When transitioning into an IPv6 host, NDP router solicitations are
- // started if the endpoint is enabled.
- //
- // If the endpoint is not currently enabled, routers will be solicited when
- // the endpoint becomes enabled (if it is still a host).
- if e.Enabled() {
- e.mu.ndp.startSolicitingRouters()
- }
+ e.mu.ndp.forwardingChanged(forwarding)
}
// Enable implements stack.NetworkEndpoint.
@@ -552,17 +567,7 @@ func (e *endpoint) Enable() tcpip.Error {
e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
}
- // If we are operating as a router, then do not solicit routers since we
- // won't process the RAs anyway.
- //
- // Routers do not process Router Advertisements (RA) the same way a host
- // does. That is, routers do not learn from RAs (e.g. on-link prefixes
- // and default routers). Therefore, soliciting RAs from other routers on
- // a link is unnecessary for routers.
- if !e.protocol.Forwarding() {
- e.mu.ndp.startSolicitingRouters()
- }
-
+ e.mu.ndp.startSolicitingRouters()
return nil
}
@@ -613,7 +618,7 @@ func (e *endpoint) disableLocked() {
return true
})
- e.mu.ndp.cleanupState(false /* hostOnly */)
+ e.mu.ndp.cleanupState()
// The endpoint may have already left the multicast group.
switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) {
@@ -786,6 +791,12 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol
}
if packetMustBeFragmented(pkt, networkMTU) {
+ if pkt.NetworkPacketInfo.IsForwardedPacket {
+ // As per RFC 2460, section 4.5:
+ // Unlike IPv4, fragmentation in IPv6 is performed only by source nodes,
+ // not by routers along a packet's delivery path.
+ return &tcpip.ErrMessageTooLong{}
+ }
sent, remain, err := e.handleFragments(r, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
@@ -928,16 +939,19 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
}
// forwardPacket attempts to forward a packet to its final destination.
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
h := header.IPv6(pkt.NetworkHeader().View())
dstAddr := h.DestinationAddress()
- if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) {
- // As per RFC 4291 section 2.5.6,
- //
- // Routers must not forward any packets with Link-Local source or
- // destination addresses to other links.
- return nil
+ // As per RFC 4291 section 2.5.6,
+ //
+ // Routers must not forward any packets with Link-Local source or
+ // destination addresses to other links.
+ if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) {
+ return &ip.ErrLinkLocalSourceAddress{}
+ }
+ if header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) {
+ return &ip.ErrLinkLocalDestinationAddress{}
}
hopLimit := h.HopLimit()
@@ -949,21 +963,56 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// packet and originate an ICMPv6 Time Exceeded message with Code 0 to
// the source of the packet. This indicates either a routing loop or
// too small an initial Hop Limit value.
- return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt)
+ //
+ // We return the original error rather than the result of returning
+ // the ICMP packet because the original error is more relevant to
+ // the caller.
+ _ = e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt)
+ return &ip.ErrTTLExceeded{}
}
+ stk := e.protocol.stack
+
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+ inNicName := stk.FindNICNameFromID(e.nic.ID())
+ outNicName := stk.FindNICNameFromID(ep.nic.ID())
+ if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesForwardDropped.Increment()
+ return nil
+ }
+
ep.handleValidatedPacket(h, pkt)
return nil
}
- r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
+ // Check extension headers for any errors requiring action during forwarding.
+ if err := e.processExtensionHeaders(h, pkt, true /* forwarding */); err != nil {
+ return &ip.ErrParameterProblem{}
+ }
+
+ r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
+ switch err.(type) {
+ case nil:
+ case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable:
+ // We return the original error rather than the result of returning the
+ // ICMP packet because the original error is more relevant to the caller.
+ _ = e.protocol.returnError(&icmpReasonNetUnreachable{}, pkt)
+ return &ip.ErrNoRoute{}
+ default:
+ return &ip.ErrOther{Err: err}
}
defer r.Release()
+ inNicName := stk.FindNICNameFromID(e.nic.ID())
+ outNicName := stk.FindNICNameFromID(r.NICID())
+ if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ // iptables is telling us to drop the packet.
+ e.stats.ip.IPTablesForwardDropped.Increment()
+ return nil
+ }
+
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
@@ -975,10 +1024,23 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// each node that forwards the packet.
newHdr.SetHopLimit(hopLimit - 1)
- return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: buffer.View(newHdr).ToVectorisedView(),
- }))
+ IsForwardedPacket: true,
+ })); err.(type) {
+ case nil:
+ return nil
+ case *tcpip.ErrMessageTooLong:
+ // As per RFC 4443, section 3.2:
+ // A Packet Too Big MUST be sent by a router in response to a packet that
+ // it cannot forward because the packet is larger than the MTU of the
+ // outgoing link.
+ _ = e.protocol.returnError(&icmpReasonPacketTooBig{}, pkt)
+ return &ip.ErrMessageTooLong{}
+ default:
+ return &ip.ErrOther{Err: err}
+ }
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
@@ -1059,6 +1121,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
stats := e.stats.ip
+ stats.ValidPacketsReceived.Increment()
+
srcAddr := h.SourceAddress()
dstAddr := h.DestinationAddress()
@@ -1075,15 +1139,54 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
addressEndpoint.DecRef()
} else if !e.IsInGroup(dstAddr) {
- if !e.protocol.Forwarding() {
+ if !e.Forwarding() {
stats.InvalidDestinationAddressesReceived.Increment()
return
}
+ switch err := e.forwardPacket(pkt); err.(type) {
+ case nil:
+ return
+ case *ip.ErrLinkLocalSourceAddress:
+ e.stats.ip.Forwarding.LinkLocalSource.Increment()
+ case *ip.ErrLinkLocalDestinationAddress:
+ e.stats.ip.Forwarding.LinkLocalDestination.Increment()
+ case *ip.ErrTTLExceeded:
+ e.stats.ip.Forwarding.ExhaustedTTL.Increment()
+ case *ip.ErrNoRoute:
+ e.stats.ip.Forwarding.Unrouteable.Increment()
+ case *ip.ErrParameterProblem:
+ e.stats.ip.Forwarding.ExtensionHeaderProblem.Increment()
+ case *ip.ErrMessageTooLong:
+ e.stats.ip.Forwarding.PacketTooBig.Increment()
+ default:
+ panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt))
+ }
+ e.stats.ip.Forwarding.Errors.Increment()
+ return
+ }
- _ = e.forwardPacket(pkt)
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and need not be forwarded.
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ // iptables is telling us to drop the packet.
+ stats.IPTablesInputDropped.Increment()
return
}
+ // Any returned error is only useful for terminating execution early, but
+ // we have nothing left to do, so we can drop it.
+ _ = e.processExtensionHeaders(h, pkt, false /* forwarding */)
+}
+
+// processExtensionHeaders processes the extension headers in the given packet.
+// Returns an error if the processing of a header failed or if the packet should
+// be discarded.
+func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffer, forwarding bool) error {
+ stats := e.stats.ip
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
+
// Create a VV to parse the packet. We don't plan to modify anything here.
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
@@ -1094,15 +1197,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
vv.AppendViews(pkt.Data().Views())
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
- // iptables filtering. All packets that reach here are intended for
- // this machine and need not be forwarded.
- inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
- // iptables is telling us to drop the packet.
- stats.IPTablesInputDropped.Increment()
- return
- }
-
var (
hasFragmentHeader bool
routerAlert *header.IPv6RouterAlertOption
@@ -1115,22 +1209,41 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
extHdr, done, err := it.Next()
if err != nil {
stats.MalformedPacketsReceived.Increment()
- return
+ return err
}
if done {
break
}
+ // As per RFC 8200, section 4:
+ //
+ // Extension headers (except for the Hop-by-Hop Options header) are
+ // not processed, inserted, or deleted by any node along a packet's
+ // delivery path until the packet reaches the node identified in the
+ // Destination Address field of the IPv6 header.
+ //
+ // Furthermore, as per RFC 8200 section 4.1, the Hop By Hop extension
+ // header is restricted to appear first in the list of extension headers.
+ //
+ // Therefore, we can immediately return once we hit any header other
+ // than the Hop-by-Hop header while forwarding a packet.
+ if forwarding {
+ if _, ok := extHdr.(header.IPv6HopByHopOptionsExtHdr); !ok {
+ return nil
+ }
+ }
+
switch extHdr := extHdr.(type) {
case header.IPv6HopByHopOptionsExtHdr:
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
if previousHeaderStart != 0 {
_ = e.protocol.returnError(&icmpReasonParameterProblem{
- code: header.ICMPv6UnknownHeader,
- pointer: previousHeaderStart,
+ code: header.ICMPv6UnknownHeader,
+ pointer: previousHeaderStart,
+ forwarding: forwarding,
}, pkt)
- return
+ return fmt.Errorf("found Hop-by-Hop header = %#v with non-zero previous header offset = %d", extHdr, previousHeaderStart)
}
optsIt := extHdr.Iter()
@@ -1139,7 +1252,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
opt, done, err := optsIt.Next()
if err != nil {
stats.MalformedPacketsReceived.Increment()
- return
+ return err
}
if done {
break
@@ -1154,7 +1267,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// There MUST only be one option of this type, regardless of
// value, per Hop-by-Hop header.
stats.MalformedPacketsReceived.Increment()
- return
+ return fmt.Errorf("found multiple Router Alert options (%#v, %#v)", opt, routerAlert)
}
routerAlert = opt
stats.OptionRouterAlertReceived.Increment()
@@ -1162,10 +1275,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
switch opt.UnknownAction() {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
- return
+ return fmt.Errorf("found unknown Hop-by-Hop header option = %#v with discard action", opt)
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
if header.IsV6MulticastAddress(dstAddr) {
- return
+ return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt)
}
fallthrough
case header.IPv6OptionUnknownActionDiscardSendICMP:
@@ -1180,10 +1293,11 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
code: header.ICMPv6UnknownOption,
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
+ forwarding: forwarding,
}, pkt)
- return
+ return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt)
default:
- panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %#v", opt))
}
}
}
@@ -1205,8 +1319,13 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: it.ParseOffset(),
+ // For the sake of consistency, we're using the value of `forwarding`
+ // here, even though it should always be false if we've reached this
+ // point. If `forwarding` is true here, we're executing undefined
+ // behavior no matter what.
+ forwarding: forwarding,
}, pkt)
- return
+ return fmt.Errorf("found unrecognized routing type with non-zero segments left in header = %#v", extHdr)
}
case header.IPv6FragmentExtHdr:
@@ -1241,7 +1360,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
if err != nil {
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
- return
+ return err
}
if done {
break
@@ -1269,7 +1388,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
default:
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
- return
+ return fmt.Errorf("known extension header = %#v present after fragment header in a non-initial fragment", lastHdr)
}
}
@@ -1278,7 +1397,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// Drop the packet as it's marked as a fragment but has no payload.
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
- return
+ return fmt.Errorf("fragment has no payload")
}
// As per RFC 2460 Section 4.5:
@@ -1296,7 +1415,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
code: header.ICMPv6ErroneousHeader,
pointer: header.IPv6PayloadLenOffset,
}, pkt)
- return
+ return fmt.Errorf("found fragment length = %d that is not a multiple of 8 octets", fragmentPayloadLen)
}
// The packet is a fragment, let's try to reassemble it.
@@ -1310,14 +1429,15 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// Parameter Problem, Code 0, message should be sent to the source of
// the fragment, pointing to the Fragment Offset field of the fragment
// packet.
- if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
+ lengthAfterReassembly := int(start) + fragmentPayloadLen
+ if lengthAfterReassembly > header.IPv6MaximumPayloadSize {
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: fragmentFieldOffset,
}, pkt)
- return
+ return fmt.Errorf("determined that reassembled packet length = %d would exceed allowed length = %d", lengthAfterReassembly, header.IPv6MaximumPayloadSize)
}
// Note that pkt doesn't have its transport header set after reassembly,
@@ -1339,7 +1459,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
if err != nil {
stats.MalformedPacketsReceived.Increment()
stats.MalformedFragmentsReceived.Increment()
- return
+ return err
}
if ready {
@@ -1361,7 +1481,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
opt, done, err := optsIt.Next()
if err != nil {
stats.MalformedPacketsReceived.Increment()
- return
+ return err
}
if done {
break
@@ -1372,10 +1492,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
switch opt.UnknownAction() {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
- return
+ return fmt.Errorf("found unknown destination header option = %#v with discard action", opt)
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
if header.IsV6MulticastAddress(dstAddr) {
- return
+ return fmt.Errorf("found unknown destination header option %#v with discard action", opt)
}
fallthrough
case header.IPv6OptionUnknownActionDiscardSendICMP:
@@ -1392,9 +1512,9 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
}, pkt)
- return
+ return fmt.Errorf("found unknown destination header option %#v with discard action", opt)
default:
- panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %#v", opt))
}
}
@@ -1402,13 +1522,19 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// If the last header in the payload isn't a known IPv6 extension header,
// handle it as if it is transport layer data.
+ // Calculate the number of octets parsed from data. We want to remove all
+ // the data except the unparsed portion located at the end, which its size
+ // is extHdr.Buf.Size().
+ trim := pkt.Data().Size() - extHdr.Buf.Size()
+
// For unfragmented packets, extHdr still contains the transport header.
// Get rid of it.
//
// For reassembled fragments, pkt.TransportHeader is unset, so this is a
// no-op and pkt.Data begins with the transport header.
- extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
- pkt.Data().Replace(extHdr.Buf)
+ trim += pkt.TransportHeader().View().Size()
+
+ pkt.Data().DeleteFront(trim)
stats.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
@@ -1425,6 +1551,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// transport protocol (e.g., UDP) has no listener, if that transport
// protocol has no alternative means to inform the sender.
_ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt)
+ return fmt.Errorf("destination port unreachable")
case stack.TransportPacketProtocolUnreachable:
// As per RFC 8200 section 4. (page 7):
// Extension headers are numbered from IANA IP Protocol Numbers
@@ -1456,6 +1583,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
code: header.ICMPv6UnknownHeader,
pointer: prevHdrIDOffset,
}, pkt)
+ return fmt.Errorf("transport protocol unreachable")
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
@@ -1469,6 +1597,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
}
}
+ return nil
}
// Close cleans up resources associated with the endpoint.
@@ -1490,8 +1619,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
}
@@ -1532,8 +1661,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
// RemovePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
@@ -1610,8 +1739,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
// AcquireAssignedAddress implements stack.AddressableEndpoint.
func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.mu.RLock()
+ defer e.mu.RUnlock()
return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB)
}
@@ -1833,7 +1962,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
return &e.stats.localStats
}
-var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1858,12 +1986,6 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- // forwarding is set to 1 when the protocol has forwarding enabled and 0
- // when it is disabled.
- //
- // Must be accessed using atomic operations.
- forwarding uint32
-
fragmentation *fragmentation.Fragmentation
}
@@ -2038,35 +2160,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
-// Forwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) Forwarding() bool {
- return uint8(atomic.LoadUint32(&p.forwarding)) == 1
-}
-
-// setForwarding sets the forwarding status for the protocol.
-//
-// Returns true if the forwarding status was updated.
-func (p *protocol) setForwarding(v bool) bool {
- if v {
- return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
- }
- return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
-}
-
-// SetForwarding implements stack.ForwardingNetworkProtocol.
-func (p *protocol) SetForwarding(v bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if !p.setForwarding(v) {
- return
- }
-
- for _, ep := range p.mu.eps {
- ep.transitionForwarding(v)
- }
-}
-
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 40a793d6b..afc6c3547 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -31,8 +31,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
+ iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -2603,7 +2604,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
+ ep := iptestutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
@@ -2802,9 +2803,9 @@ func TestFragmentationWritePacket(t *testing.T) {
for _, ft := range fragmentationTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
source := pkt.Clone()
- ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
@@ -2858,7 +2859,7 @@ func TestFragmentationWritePackets(t *testing.T) {
insertAfter: 1,
},
}
- tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
+ tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
@@ -2868,14 +2869,14 @@ func TestFragmentationWritePackets(t *testing.T) {
for i := 0; i < test.insertBefore; i++ {
pkts.PushBack(tinyPacket.Clone())
}
- pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
source := pkt
pkts.PushBack(pkt.Clone())
for i := 0; i < test.insertAfter; i++ {
pkts.PushBack(tinyPacket.Clone())
}
- ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
@@ -2980,8 +2981,8 @@ func TestFragmentationErrors(t *testing.T) {
for _, ft := range tests {
t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
- ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
r := buildRoute(t, ep)
err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
@@ -3003,52 +3004,289 @@ func TestFragmentationErrors(t *testing.T) {
func TestForwarding(t *testing.T) {
const (
- nicID1 = 1
- nicID2 = 2
+ incomingNICID = 1
+ outgoingNICID = 2
randomSequence = 123
randomIdent = 42
)
- ipv6Addr1 := tcpip.AddressWithPrefix{
+ incomingIPv6Addr := tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("10::1").To16()),
PrefixLen: 64,
}
- ipv6Addr2 := tcpip.AddressWithPrefix{
+ outgoingIPv6Addr := tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("11::1").To16()),
PrefixLen: 64,
}
+ multicastIPv6Addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("ff00::").To16()),
+ PrefixLen: 64,
+ }
+
remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16())
remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16())
+ unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16())
+ linkLocalIPv6Addr := tcpip.Address(net.ParseIP("fe80::").To16())
tests := []struct {
- name string
- TTL uint8
- expectErrorICMP bool
+ name string
+ extHdr func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker)
+ TTL uint8
+ expectErrorICMP bool
+ expectPacketForwarded bool
+ payloadLength int
+ countUnrouteablePackets uint64
+ sourceAddr tcpip.Address
+ destAddr tcpip.Address
+ icmpType header.ICMPv6Type
+ icmpCode header.ICMPv6Code
+ expectPacketUnrouteableError bool
+ expectLinkLocalSourceError bool
+ expectLinkLocalDestError bool
+ expectExtensionHeaderError bool
}{
{
name: "TTL of zero",
TTL: 0,
expectErrorICMP: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ icmpType: header.ICMPv6TimeExceeded,
+ icmpCode: header.ICMPv6HopLimitExceeded,
},
{
name: "TTL of one",
TTL: 1,
expectErrorICMP: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ icmpType: header.ICMPv6TimeExceeded,
+ icmpCode: header.ICMPv6HopLimitExceeded,
},
{
- name: "TTL of two",
- TTL: 2,
- expectErrorICMP: false,
+ name: "TTL of two",
+ TTL: 2,
+ expectPacketForwarded: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ },
+ {
+ name: "TTL of three",
+ TTL: 3,
+ expectPacketForwarded: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ },
+ {
+ name: "Max TTL",
+ TTL: math.MaxUint8,
+ expectPacketForwarded: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ },
+ {
+ name: "Network unreachable",
+ TTL: 2,
+ expectErrorICMP: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: unreachableIPv6Addr,
+ icmpType: header.ICMPv6DstUnreachable,
+ icmpCode: header.ICMPv6NetworkUnreachable,
+ expectPacketUnrouteableError: true,
+ },
+ {
+ name: "Multicast destination",
+ TTL: 2,
+ countUnrouteablePackets: 1,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: multicastIPv6Addr.Address,
+ expectPacketForwarded: true,
+ },
+ {
+ name: "Link local destination",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: linkLocalIPv6Addr,
+ expectLinkLocalDestError: true,
+ },
+ {
+ name: "Link local source",
+ TTL: 2,
+ sourceAddr: linkLocalIPv6Addr,
+ destAddr: remoteIPv6Addr2,
+ expectLinkLocalSourceError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option skippable action",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6UnknownOption(), checker.IPv6UnknownOption()))
+ },
+ expectPacketForwarded: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard action",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard and send icmp action (unicast)",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectErrorICMP: true,
+ icmpType: header.ICMPv6ParamProblem,
+ icmpCode: header.ICMPv6UnknownOption,
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard and send icmp action (multicast)",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: multicastIPv6Addr.Address,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectErrorICMP: true,
+ icmpType: header.ICMPv6ParamProblem,
+ icmpCode: header.ICMPv6UnknownOption,
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectErrorICMP: true,
+ icmpType: header.ICMPv6ParamProblem,
+ icmpCode: header.ICMPv6UnknownOption,
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: multicastIPv6Addr.Address,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID, nil
+ },
+ expectExtensionHeaderError: true,
},
{
- name: "TTL of three",
- TTL: 3,
- expectErrorICMP: false,
+ name: "Hopbyhop with router alert option",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 0,
+
+ // Router Alert option.
+ 5, 2, 0, 0, 0, 0,
+ }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)))
+ },
+ expectPacketForwarded: true,
+ },
+ {
+ name: "Hopbyhop with two router alert options",
+ TTL: 2,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
+ return []byte{
+ nextHdr, 1,
+
+ // Router Alert option.
+ 5, 2, 0, 0, 0, 0,
+
+ // Router Alert option.
+ 5, 2, 0, 0, 0, 0,
+ }, hopByHopExtHdrID, nil
+ },
+ expectExtensionHeaderError: true,
+ },
+ {
+ name: "Can't fragment",
+ TTL: 2,
+ payloadLength: header.IPv6MinimumMTU + 1,
+ expectErrorICMP: true,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: remoteIPv6Addr2,
+ icmpType: header.ICMPv6PacketTooBig,
+ icmpCode: header.ICMPv6UnusedCode,
},
{
- name: "Max TTL",
- TTL: math.MaxUint8,
- expectErrorICMP: false,
+ name: "Can't fragment multicast",
+ TTL: 2,
+ payloadLength: header.IPv6MinimumMTU + 1,
+ sourceAddr: remoteIPv6Addr1,
+ destAddr: multicastIPv6Addr.Address,
+ expectErrorICMP: true,
+ icmpType: header.ICMPv6PacketTooBig,
+ icmpCode: header.ICMPv6UnusedCode,
},
}
@@ -3059,41 +3297,60 @@ func TestForwarding(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
})
// We expect at most a single packet in response to our ICMP Echo Request.
- e1 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ incomingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
- ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1}
- if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err)
+ incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr}
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err)
}
- e2 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
- ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2}
- if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err)
+ outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr}
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{
- Destination: ipv6Addr1.Subnet(),
- NIC: nicID1,
+ Destination: incomingIPv6Addr.Subnet(),
+ NIC: incomingNICID,
+ },
+ {
+ Destination: outgoingIPv6Addr.Subnet(),
+ NIC: outgoingNICID,
},
{
- Destination: ipv6Addr2.Subnet(),
- NIC: nicID2,
+ Destination: multicastIPv6Addr.Subnet(),
+ NIC: outgoingNICID,
},
})
- if err := s.SetForwarding(ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err)
+ if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err)
}
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize)
- icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ transportProtocol := header.ICMPv6ProtocolNumber
+ extHdrBytes := []byte{}
+ extHdrChecker := checker.IPv6ExtHdr()
+ if test.extHdr != nil {
+ nextHdrID := hopByHopExtHdrID
+ extHdrBytes, nextHdrID, extHdrChecker = test.extHdr(uint8(header.ICMPv6ProtocolNumber))
+ transportProtocol = tcpip.TransportProtocolNumber(nextHdrID)
+ }
+ extHdrLen := len(extHdrBytes)
+
+ ipHeaderLength := header.IPv6MinimumSize
+ icmpHeaderLength := header.ICMPv6MinimumSize
+ totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen
+ hdr := buffer.NewPrependable(totalLength)
+ hdr.Prepend(test.payloadLength)
+ icmp := header.ICMPv6(hdr.Prepend(icmpHeaderLength))
+
icmp.SetIdent(randomIdent)
icmp.SetSequence(randomSequence)
icmp.SetType(header.ICMPv6EchoRequest)
@@ -3101,52 +3358,72 @@ func TestForwarding(t *testing.T) {
icmp.SetChecksum(0)
icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmp,
- Src: remoteIPv6Addr1,
- Dst: remoteIPv6Addr2,
+ Src: test.sourceAddr,
+ Dst: test.destAddr,
}))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ copy(hdr.Prepend(extHdrLen), extHdrBytes)
+ ip := header.IPv6(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- TransportProtocol: header.ICMPv6ProtocolNumber,
+ PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength),
+ TransportProtocol: transportProtocol,
HopLimit: test.TTL,
- SrcAddr: remoteIPv6Addr1,
- DstAddr: remoteIPv6Addr2,
+ SrcAddr: test.sourceAddr,
+ DstAddr: test.destAddr,
})
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
})
- e1.InjectInbound(ProtocolNumber, requestPkt)
+ incomingEndpoint.InjectInbound(ProtocolNumber, requestPkt)
+
+ reply, ok := incomingEndpoint.Read()
if test.expectErrorICMP {
- reply, ok := e1.Read()
if !ok {
- t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC")
+ t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
+ }
+
+ // As per RFC 4443, page 9:
+ //
+ // The returned ICMP packet will contain as much of invoking packet
+ // as possible without the ICMPv6 packet exceeding the minimum IPv6
+ // MTU.
+ expectedICMPPayloadLength := func() int {
+ maxICMPPayloadLength := header.IPv6MinimumMTU - ipHeaderLength - icmpHeaderLength
+ if len(hdr.View()) > maxICMPPayloadLength {
+ return maxICMPPayloadLength
+ }
+ return len(hdr.View())
}
checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(ipv6Addr1.Address),
- checker.DstAddr(remoteIPv6Addr1),
+ checker.SrcAddr(incomingIPv6Addr.Address),
+ checker.DstAddr(test.sourceAddr),
checker.TTL(DefaultTTL),
checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6TimeExceeded),
- checker.ICMPv6Code(header.ICMPv6HopLimitExceeded),
- checker.ICMPv6Payload([]byte(hdr.View())),
+ checker.ICMPv6Type(test.icmpType),
+ checker.ICMPv6Code(test.icmpCode),
+ checker.ICMPv6Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])),
),
)
- if n := e2.Drain(); n != 0 {
+ if n := outgoingEndpoint.Drain(); n != 0 {
t.Fatalf("got e2.Drain() = %d, want = 0", n)
}
- } else {
- reply, ok := e2.Read()
+ } else if ok {
+ t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply)
+ }
+
+ reply, ok = outgoingEndpoint.Read()
+ if test.expectPacketForwarded {
if !ok {
t.Fatal("expected ICMP Echo Request packet through outgoing NIC")
}
- checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
- checker.SrcAddr(remoteIPv6Addr1),
- checker.DstAddr(remoteIPv6Addr2),
+ checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())),
+ checker.SrcAddr(test.sourceAddr),
+ checker.DstAddr(test.destAddr),
checker.TTL(test.TTL-1),
+ extHdrChecker,
checker.ICMPv6(
checker.ICMPv6Type(header.ICMPv6EchoRequest),
checker.ICMPv6Code(header.ICMPv6UnusedCode),
@@ -3154,9 +3431,46 @@ func TestForwarding(t *testing.T) {
),
)
- if n := e1.Drain(); n != 0 {
+ if n := incomingEndpoint.Drain(); n != 0 {
t.Fatalf("got e1.Drain() = %d, want = 0", n)
}
+ } else if ok {
+ t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
+ }
+
+ boolToInt := func(val bool) uint64 {
+ if val {
+ return 1
+ }
+ return 0
+ }
+
+ if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 1); got != want {
+ t.Errorf("got rt.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value(), boolToInt(test.expectExtensionHeaderError); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpType == header.ICMPv6PacketTooBig); got != want {
+ t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want)
}
})
}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index d6e0a81a6..f0ff111c5 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -48,7 +48,7 @@ const (
// defaultHandleRAs is the default configuration for whether or not to
// handle incoming Router Advertisements as a host.
- defaultHandleRAs = true
+ defaultHandleRAs = HandlingRAsEnabledWhenForwardingDisabled
// defaultDiscoverDefaultRouters is the default configuration for
// whether or not to discover default routers from incoming Router
@@ -301,10 +301,60 @@ type NDPDispatcher interface {
OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA)
}
+var _ fmt.Stringer = HandleRAsConfiguration(0)
+
+// HandleRAsConfiguration enumerates when RAs may be handled.
+type HandleRAsConfiguration int
+
+const (
+ // HandlingRAsDisabled indicates that Router Advertisements will not be
+ // handled.
+ HandlingRAsDisabled HandleRAsConfiguration = iota
+
+ // HandlingRAsEnabledWhenForwardingDisabled indicates that router
+ // advertisements will only be handled when forwarding is disabled.
+ HandlingRAsEnabledWhenForwardingDisabled
+
+ // HandlingRAsAlwaysEnabled indicates that Router Advertisements will always
+ // be handled, even when forwarding is enabled.
+ HandlingRAsAlwaysEnabled
+)
+
+// String implements fmt.Stringer.
+func (c HandleRAsConfiguration) String() string {
+ switch c {
+ case HandlingRAsDisabled:
+ return "HandlingRAsDisabled"
+ case HandlingRAsEnabledWhenForwardingDisabled:
+ return "HandlingRAsEnabledWhenForwardingDisabled"
+ case HandlingRAsAlwaysEnabled:
+ return "HandlingRAsAlwaysEnabled"
+ default:
+ return fmt.Sprintf("HandleRAsConfiguration(%d)", c)
+ }
+}
+
+// enabled returns true iff Router Advertisements may be handled given the
+// specified forwarding status.
+func (c HandleRAsConfiguration) enabled(forwarding bool) bool {
+ switch c {
+ case HandlingRAsDisabled:
+ return false
+ case HandlingRAsEnabledWhenForwardingDisabled:
+ return !forwarding
+ case HandlingRAsAlwaysEnabled:
+ return true
+ default:
+ panic(fmt.Sprintf("unhandled HandleRAsConfiguration = %d", c))
+ }
+}
+
// NDPConfigurations is the NDP configurations for the netstack.
type NDPConfigurations struct {
// The number of Router Solicitation messages to send when the IPv6 endpoint
// becomes enabled.
+ //
+ // Ignored unless configured to handle Router Advertisements.
MaxRtrSolicitations uint8
// The amount of time between transmitting Router Solicitation messages.
@@ -318,8 +368,9 @@ type NDPConfigurations struct {
// Must be greater than or equal to 0s.
MaxRtrSolicitationDelay time.Duration
- // HandleRAs determines whether or not Router Advertisements are processed.
- HandleRAs bool
+ // HandleRAs is the configuration for when Router Advertisements should be
+ // handled.
+ HandleRAs HandleRAsConfiguration
// DiscoverDefaultRouters determines whether or not default routers are
// discovered from Router Advertisements, as per RFC 4861 section 6. This
@@ -654,7 +705,8 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// per-interface basis; it is a protocol-wide configuration, so we check the
// protocol's forwarding flag to determine if the IPv6 endpoint is forwarding
// packets.
- if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() {
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) {
+ ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment()
return
}
@@ -1609,44 +1661,16 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotifyInner(tempAddrs map[t
delete(tempAddrs, tempAddr)
}
-// removeSLAACAddresses removes all SLAAC addresses.
-//
-// If keepLinkLocal is false, the SLAAC generated link-local address is removed.
-//
-// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) {
- linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
- var linkLocalPrefixes int
- for prefix, state := range ndp.slaacPrefixes {
- // RFC 4862 section 5 states that routers are also expected to generate a
- // link-local address so we do not invalidate them if we are cleaning up
- // host-only state.
- if keepLinkLocal && prefix == linkLocalSubnet {
- linkLocalPrefixes++
- continue
- }
-
- ndp.invalidateSLAACPrefix(prefix, state)
- }
-
- if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
- panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes))
- }
-}
-
// cleanupState cleans up ndp's state.
//
-// If hostOnly is true, then only host-specific state is cleaned up.
-//
// This function invalidates all discovered on-link prefixes, discovered
// routers, and auto-generated addresses.
//
-// If hostOnly is true, then the link-local auto-generated address aren't
-// invalidated as routers are also expected to generate a link-local address.
-//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupState(hostOnly bool) {
- ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */)
+func (ndp *ndpState) cleanupState() {
+ for prefix, state := range ndp.slaacPrefixes {
+ ndp.invalidateSLAACPrefix(prefix, state)
+ }
for prefix := range ndp.onLinkPrefixes {
ndp.invalidateOnLinkPrefix(prefix)
@@ -1670,6 +1694,10 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
// 6.3.7. If routers are already being solicited, this function does nothing.
//
+// If ndp is not configured to handle Router Advertisements, routers will not
+// be solicited as there is no point soliciting routers if we don't handle their
+// advertisements.
+//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
if ndp.rtrSolicitTimer.timer != nil {
@@ -1682,6 +1710,10 @@ func (ndp *ndpState) startSolicitingRouters() {
return
}
+ if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) {
+ return
+ }
+
// Calculate the random delay before sending our first RS, as per RFC
// 4861 section 6.3.7.
var delay time.Duration
@@ -1774,6 +1806,32 @@ func (ndp *ndpState) startSolicitingRouters() {
}
}
+// forwardingChanged handles a change in forwarding configuration.
+//
+// If transitioning to a host, router solicitation will be started. Otherwise,
+// router solicitation will be stopped if NDP is not configured to handle RAs
+// as a router.
+//
+// Precondition: ndp.ep.mu must be locked.
+func (ndp *ndpState) forwardingChanged(forwarding bool) {
+ if forwarding {
+ if ndp.configs.HandleRAs.enabled(forwarding) {
+ return
+ }
+
+ ndp.stopSolicitingRouters()
+ return
+ }
+
+ // Solicit routers when transitioning to a host.
+ //
+ // If the endpoint is not currently enabled, routers will be solicited when
+ // the endpoint becomes enabled (if it is still a host).
+ if ndp.ep.Enabled() {
+ ndp.startSolicitingRouters()
+ }
+}
+
// stopSolicitingRouters stops soliciting routers. If routers are not currently
// being solicited, this function does nothing.
//
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 52b9a200c..234e34952 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -732,15 +732,7 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
}
func TestNDPValidation(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
- t.Helper()
-
- // Create a stack with the assigned link-local address lladdr0
- // and an endpoint to lladdr1.
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
-
- return s, ep
- }
+ const nicID = 1
handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
var extHdrs header.IPv6ExtHdrSerializer
@@ -865,6 +857,11 @@ func TestNDPValidation(t *testing.T) {
},
}
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
+ if err != nil {
+ t.Fatal(err)
+ }
+
for _, typ := range types {
for _, isRouter := range []bool{false, true} {
name := typ.name
@@ -875,13 +872,35 @@ func TestNDPValidation(t *testing.T) {
t.Run(name, func(t *testing.T) {
for _, test := range subTests {
t.Run(test.name, func(t *testing.T) {
- s, ep := setup(t)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ })
if isRouter {
- // Enabling forwarding makes the stack act as a router.
- s.SetForwarding(ProtocolNumber, true)
+ if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err)
+ }
}
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
+ if err != nil {
+ t.Fatal("cannot find network endpoint instance for IPv6")
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }})
+
stats := s.Stats().ICMP.V6.PacketsReceived
invalid := stats.Invalid
routerOnly := stats.RouterOnlyPacketsDroppedByHost
@@ -906,12 +925,12 @@ func TestNDPValidation(t *testing.T) {
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
+ t.Errorf("got invalid.Value() = %d, want = 0", got)
}
- // RouterOnlyPacketsReceivedByHost count should initially be 0.
+ // Should initially not have dropped any packets.
if got := routerOnly.Value(); got != 0 {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ t.Errorf("got routerOnly.Value() = %d, want = 0", got)
}
if t.Failed() {
@@ -931,18 +950,18 @@ func TestNDPValidation(t *testing.T) {
want = 1
}
if got := invalid.Value(); got != want {
- t.Errorf("got invalid = %d, want = %d", got, want)
+ t.Errorf("got invalid.Value() = %d, want = %d", got, want)
}
want = 0
if test.valid && !isRouter && typ.routerOnly {
- // RouterOnlyPacketsReceivedByHost count should have increased.
+ // Router only packets are expected to be dropped when operating
+ // as a host.
want = 1
}
if got := routerOnly.Value(); got != want {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want)
+ t.Errorf("got routerOnly.Value() = %d, want = %d", got, want)
}
-
})
}
})
diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go
index c2758352f..2f18f60e8 100644
--- a/pkg/tcpip/network/ipv6/stats.go
+++ b/pkg/tcpip/network/ipv6/stats.go
@@ -29,6 +29,10 @@ type Stats struct {
// ICMP holds ICMPv6 statistics.
ICMP tcpip.ICMPv6Stats
+
+ // UnhandledRouterAdvertisements is the number of Router Advertisements that
+ // were observed but not handled.
+ UnhandledRouterAdvertisements *tcpip.StatCounter
}
// IsNetworkEndpointStats implements stack.NetworkEndpointStats.